diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
new file mode 100644
index 0000000000000..a74759f0ce620
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.sources.Aggregation;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * have aggregate functions, together with their group-by columns, pushed down to them.
+ *
+ * @since 3.1.0
+ */
+@Evolving
+public interface SupportsPushDownAggregates extends ScanBuilder {
+
+  /**
+   * Pushes an {@link Aggregation} (aggregate functions plus group-by columns) down to the
+   * data source. The Aggregation can be pushed down only if all the Aggregate Functions
+   * it contains can be pushed down.
+   */
+  void pushAggregation(Aggregation aggregation);
+
+  /**
+   * Returns the aggregation that was pushed to the data source via
+   * {@link #pushAggregation(Aggregation aggregation)}.
+   */
+  Aggregation pushedAggregation();
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala
new file mode 100644
index 0000000000000..20cb93301cd85
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/aggregates.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+// Aggregate functions and group-by column names to push down to a data source.
+case class Aggregation(aggregateExpressions: Seq[AggregateFunc], groupByExpressions: Seq[String])
+
+// Base type of the aggregate functions that can be pushed down.
+abstract class AggregateFunc
+
+// Todo: add Count
+case class Avg(column: String, isDistinct: Boolean) extends AggregateFunc
+case class Min(column: String) extends AggregateFunc
+case class Max(column: String) extends AggregateFunc
+case class Sum(column: String, isDistinct: Boolean) extends AggregateFunc
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index df3b9f2a4e9cb..d84bb9482dd8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.{BaseRelation, Filter}
+import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.Utils
@@ -102,6 +102,7 @@ case class RowDataSourceScanExec(
     requiredSchema: StructType,
     filters: Set[Filter],
     handledFilters: Set[Filter],
+    aggregation: Aggregation,
     rdd: RDD[InternalRow],
     @transient relation: BaseRelation,
     tableIdentifier: Option[TableIdentifier])
@@ -132,9 +133,14 @@ case class RowDataSourceScanExec(
     val markedFilters = for (filter <- filters) yield {
       if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
     }
+    // Every aggregate and group-by column gets the same `*` prefix that marks handled filters.
+    val markedAggregates = aggregation.aggregateExpressions.map(agg => s"*$agg")
+    val markedGroupby = aggregation.groupByExpressions.map(name => s"*$name")
     Map(
       "ReadSchema" -> requiredSchema.catalogString,
-      "PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
+      "PushedFilters" -> markedFilters.mkString("[", ", ", "]"),
+      "PushedAggregates" -> markedAggregates.mkString("[", ", ", "]"),
+      "PushedGroupby" -> markedGroupby.mkString("[", ", ", "]"))
   }
 
   // Don't care about `rdd` and `tableIdentifier` when canonicalizing.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a097017222b57..053c92a70e642 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTable, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project, UncacheTable}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -358,6 +359,7 @@ object DataSourceStrategy
         l.output.toStructType,
         Set.empty,
         Set.empty,
+        Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]),
         toCatalystRDD(l, baseRelation.buildScan()),
         baseRelation,
         None) :: Nil
@@ -431,6 +433,7 @@ object DataSourceStrategy
         requestedColumns.toStructType,
         pushedFilters.toSet,
         handledFilters,
+        Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]),
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
         relation.catalogTable.map(_.identifier))
@@ -453,6 +456,7 @@ object DataSourceStrategy
+ requestedColumns.toStructType,
         pushedFilters.toSet,
         handledFilters,
+        Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]),
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
         relation.catalogTable.map(_.identifier))
@@ -700,6 +704,80 @@ object DataSourceStrategy
     (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
   }
 
+  /**
+   * Tries to translate a Catalyst [[AggregateExpression]] into a data source [[AggregateFunc]].
+   *
+   * Only MIN, MAX, AVG and SUM are handled, and only over a plain column or a single binary
+   * arithmetic expression (`+`, `-`, `*`, `/`) of two columns. `Cast`, `CheckOverflow` and
+   * `PromotePrecision` wrappers around those shapes are looked through.
+   *
+   * @return the translated aggregate, or None if the expression cannot be pushed down.
+   */
+  protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
+
+    // Renders a supported binary operator as a string; "" for every other expression.
+    def arithmeticAsString(e: Expression): String = e match {
+      case Add(left, right, _) => arithmeticExpressionAsString(left, right, "+")
+      case Subtract(left, right, _) => arithmeticExpressionAsString(left, right, "-")
+      case Multiply(left, right, _) => arithmeticExpressionAsString(left, right, "*")
+      case Divide(left, right, _) => arithmeticExpressionAsString(left, right, "/")
+      case _ => ""
+    }
+
+    // Column name or arithmetic expression as a string; "" means "cannot be translated".
+    def columnAsString(e: Expression): String = e match {
+      case AttributeReference(name, _, _, _) => name
+      case Cast(AttributeReference(name, _, _, _), _, _) => name
+      case Cast(child, _, _) => arithmeticAsString(child)
+      case CheckOverflow(child, _, _) => arithmeticAsString(child)
+      case other => arithmeticAsString(other)
+    }
+
+    def translatedColumn(e: Expression): Option[String] =
+      Some(columnAsString(e)).filter(_.nonEmpty)
+
+    if (aggregates.filter.isDefined) {
+      // `Aggregation` has no way to carry a FILTER (WHERE ...) clause, so translating this
+      // aggregate would silently drop the filter.
+      None
+    } else {
+      aggregates.aggregateFunction match {
+        case aggregate.Min(child) => translatedColumn(child).map(Min(_))
+        case aggregate.Max(child) => translatedColumn(child).map(Max(_))
+        case aggregate.Average(child) =>
+          translatedColumn(child).map(Avg(_, aggregates.isDistinct))
+        case aggregate.Sum(child) =>
+          translatedColumn(child).map(Sum(_, aggregates.isDistinct))
+        case _ => None
+      }
+    }
+  }
+
+  /**
+   * Renders a binary arithmetic expression over two columns as `left sign right`.
+   *
+   * @return the rendered expression, or "" when either operand is not a column reference
+   *         (optionally wrapped in `Cast` and/or `PromotePrecision`).
+   */
+  private def arithmeticExpressionAsString(
+      left: Expression,
+      right: Expression,
+      sign: String): String = {
+
+    def operandName(e: Expression): Option[String] = e match {
+      case a: AttributeReference => Some(a.name)
+      case Cast(a: AttributeReference, _, _) => Some(a.name)
+      case PromotePrecision(a: AttributeReference) => Some(a.name)
+      case PromotePrecision(Cast(a: AttributeReference, _, _)) => Some(a.name)
+      case _ => None
+    }
+
+    (operandName(left), operandName(right)) match {
+      case (Some(leftName), Some(rightName)) => s"$leftName $sign $rightName"
+      case _ => ""
+    }
+  }
+
   /**
    * Convert RDD of Row into RDD of InternalRow with objects in catalyst types
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 6e8b7ea678264..8599a0bbc74c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -191,6 +191,9 @@ class JDBCOptions(
   // An option to allow/disallow pushing down predicate into JDBC data source
   val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean
 
+  // An option to allow/disallow pushing down aggregate into JDBC data source
+  val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "true").toBoolean
+
   // The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
   // by --files option of spark-submit or manually
   val keytab = {
@@ -260,6 +263,7 @@ object JDBCOptions {
   val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
   val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
   val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
+  val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
   val JDBC_KEYTAB = newOption("keytab")
   val JDBC_PRINCIPAL = newOption("principal")
   val JDBC_TABLE_COMMENT = newOption("tableComment")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 87ca78db59b29..ea968748b58a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ 
+ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.execution.datasources.jdbc
 
 import java.sql.{Connection, PreparedStatement, ResultSet}
+import java.util.StringTokenizer
 
+import scala.collection.mutable.ArrayBuilder
 import scala.util.control.NonFatal
 
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -133,6 +135,70 @@ object JDBCRDD extends Logging {
     })
   }
 
+  /**
+   * Compiles the pushed-down aggregate functions into SQL select expressions.
+   *
+   * A column string containing `+`, `-`, `*` or `/` is treated as an arithmetic expression
+   * whose operands are quoted individually; any other string is quoted as one identifier.
+   * Aggregate functions other than MIN, MAX, SUM and AVG are skipped.
+   *
+   * @return one SQL expression per compiled aggregate, in the order of `aggregates`.
+   */
+  def compileAggregates(
+      aggregates: Seq[AggregateFunc],
+      dialect: JdbcDialect): (Array[String]) = {
+    def quoted(column: String): String = {
+      if (isArithmetic(column)) quoteEachCols(column, dialect)
+      else dialect.quoteIdentifier(column)
+    }
+    def distinctPrefix(isDistinct: Boolean): String = if (isDistinct) "DISTINCT " else ""
+
+    aggregates.collect {
+      case Min(column) => s"MIN(${quoted(column)})"
+      case Max(column) => s"MAX(${quoted(column)})"
+      case Sum(column, isDistinct) => s"SUM(${distinctPrefix(isDistinct)}${quoted(column)})"
+      case Avg(column, isDistinct) => s"AVG(${distinctPrefix(isDistinct)}${quoted(column)})"
+    }.toArray
+  }
+
+  /**
+   * Quoted names of the columns an aggregate reads: every operand of an arithmetic
+   * expression, or the single column otherwise.
+   */
+  private def quotedOperands(column: String, dialect: JdbcDialect): Seq[String] = {
+    if (isArithmetic(column)) {
+      arithmeticTokens(column).filterNot(isOperator).map(dialect.quoteIdentifier)
+    } else {
+      Seq(dialect.quoteIdentifier(column))
+    }
+  }
+
+  private def isOperator(token: String): Boolean = token.length == 1 && isArithmetic(token)
+
+  private def isArithmetic(column: String): Boolean = column.exists(c => "+-*/".contains(c))
+
+  /** Splits on the arithmetic operators, keeping them as tokens; blank tokens are dropped. */
+  private def arithmeticTokens(column: String): Seq[String] = {
+    val tokenizer = new StringTokenizer(column, "+-*/", true)
+    val tokens = ArrayBuilder.make[String]
+    while (tokenizer.hasMoreTokens) {
+      val token = tokenizer.nextToken().trim
+      if (token.nonEmpty) tokens += token
+    }
+    tokens.result.toSeq
+  }
+
+  /**
+   * Quotes every operand of an arithmetic expression, leaving the operators untouched.
+   * Tokens are consumed one at a time, so a leading, trailing or repeated operator cannot
+   * exhaust the tokenizer.
+   */
+  private def quoteEachCols(column: String, dialect: JdbcDialect): String = {
+    arithmeticTokens(column)
+      .map(token => if (isOperator(token)) token else dialect.quoteIdentifier(token))
+      .mkString(" ")
+  }
+
   /**
    * Build and return JDBCRDD from the given information.
    *
@@ -152,7 +218,9 @@ object JDBCRDD extends Logging {
       requiredColumns: Array[String],
       filters: Array[Filter],
       parts: Array[Partition],
-      options: JDBCOptions): RDD[InternalRow] = {
+      options: JDBCOptions,
+      aggregation: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]))
+    : RDD[InternalRow] = {
     val url = options.url
     val dialect = JdbcDialects.get(url)
     val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
@@ -164,7 +232,8 @@ object JDBCRDD extends Logging {
       filters,
       parts,
       url,
-      options)
+      options,
+      aggregation)
   }
 }
 
@@ -181,7 +250,8 @@ private[jdbc] class JDBCRDD(
     filters: Array[Filter],
     partitions: Array[Partition],
     url: String,
-    options: JDBCOptions)
+    options: JDBCOptions,
+    aggregation: Aggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]))
   extends RDD[InternalRow](sc, Nil) {
 
   /**
@@ -189,13 +259,104 @@ private[jdbc] class JDBCRDD(
    */
   override def getPartitions: Array[Partition] = partitions
 
   /**
    * `columns`, but as a String suitable for injection into a SQL query.
+   * When aggregates are pushed down, the list is instead the compiled aggregate expressions
+   * followed by the quoted group-by columns.
    */
   private val columnList: String = {
-    val sb = new StringBuilder()
-    columns.foreach(x => sb.append(",").append(x))
-    if (sb.isEmpty) "1" else sb.substring(1)
+    val dialect = JdbcDialects.get(url)
+    val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)
+    if (compiledAgg.isEmpty) {
+      if (columns.isEmpty) "1" else columns.mkString(",")
+    } else {
+      (compiledAgg ++ aggregation.groupByExpressions.map(dialect.quoteIdentifier)).mkString(", ")
+    }
   }
 
+  /**
+   * Schema of the rows produced by the generated query. It is `schema` itself unless
+   * aggregates are pushed down; then it holds one field per compiled aggregate followed by
+   * one field per group-by column, in the same order as `columnList`.
+   */
+  private val updatedSchema: StructType = {
+    val dialect = JdbcDialects.get(url)
+    val fieldsByColumn: Map[String, StructField] = columns.zip(schema.fields).toMap
+
+    def fieldOf(quotedColumn: String): StructField = fieldsByColumn.getOrElse(quotedColumn,
+      throw new NoSuchElementException(s"Column $quotedColumn is not in the scan schema"))
+
+    def inputField(column: String): StructField = {
+      val operands = JDBCRDD.quotedOperands(column, dialect).map(fieldOf)
+      require(operands.nonEmpty, s"Aggregate column '$column' does not reference any column")
+      operands.reduceLeft(widerField)
+    }
+
+    // Work from the typed aggregates instead of re-parsing the compiled SQL text, so that a
+    // group-by column whose name contains "MAX", "SUM", ... or an operator character is
+    // never mistaken for an aggregate.
+    val aggregateFields = aggregation.aggregateExpressions.collect {
+      case Min(column) => inputField(column)
+      case Max(column) => inputField(column)
+      case Sum(column, _) => sumResultField(inputField(column))
+      case Avg(column, _) => avgResultField(inputField(column))
+    }
+    if (aggregateFields.isEmpty) {
+      schema
+    } else {
+      val groupByFields =
+        aggregation.groupByExpressions.map(name => fieldOf(dialect.quoteIdentifier(name)))
+      StructType(aggregateFields ++ groupByFields)
+    }
+  }
+
+  /**
+   * Result field of SUM over `input`. Same as Spark, promote to the largest types to prevent
+   * overflows: integral types become Long, decimals gain 10 digits of precision (what is
+   * done in Sum.resultType) and everything else becomes Double.
+   */
+  private def sumResultField(input: StructField): StructField = {
+    val resultType = input.dataType match {
+      case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale)
+      case _: IntegralType => LongType
+      case _ => DoubleType
+    }
+    StructField(input.name, resultType, input.nullable)
+  }
+
+  /**
+   * Result field of AVG over `input`. Same as Spark: decimals gain 4 digits of precision and
+   * scale (what is done in Average.resultType); every other type becomes Double.
+   */
+  private def avgResultField(input: StructField): StructField = {
+    val resultType = input.dataType match {
+      case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 4, scale + 4)
+      case _ => DoubleType
+    }
+    StructField(input.name, resultType, input.nullable)
+  }
+
+  /**
+   * Picks the wider of two operand fields: between two decimals the higher precision wins,
+   * otherwise the order is Byte < Short < Integer < Long < Float < Decimal < Double. Ties
+   * keep `current`; non-numeric types rank below every numeric type.
+   */
+  private def widerField(current: StructField, candidate: StructField): StructField = {
+    def rank(dataType: DataType): Int = dataType match {
+      case ByteType => 0
+      case ShortType => 1
+      case IntegerType => 2
+      case LongType => 3
+      case FloatType => 4
+      case _: DecimalType => 5
+      case DoubleType => 6
+      case _ => -1
+    }
+    val candidateIsWider = (current.dataType, candidate.dataType) match {
+      case (cur: DecimalType, cand: DecimalType) => cand.precision > cur.precision
+      case (cur, cand) => rank(cur) < rank(cand)
+    }
+    if (candidateIsWider) candidate else current
+  }
+
   /**
@@ -221,6 +382,15 @@ private[jdbc] class JDBCRDD(
     }
   }
 
+  /**
+   * A GROUP BY clause representing pushed-down grouping columns, or "" when there are none.
+   */
+  private def getGroupByClause: String = {
+    val dialect = JdbcDialects.get(url)
+    val quotedColumns = aggregation.groupByExpressions.map(dialect.quoteIdentifier)
+    if (quotedColumns.isEmpty) "" else s"GROUP BY ${quotedColumns.mkString(", ")}"
+  }
+
   /**
    * Runs the SQL query against the JDBC driver.
    *
@@ -296,13 +466,15 @@ private[jdbc] class JDBCRDD(
 
     val myWhereClause = getWhereClause(part)
 
-    val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause"
+    val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
+      s" $getGroupByClause"
     stmt = conn.prepareStatement(sqlText,
         ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
     stmt.setFetchSize(options.fetchSize)
     stmt.setQueryTimeout(options.queryTimeout)
     rs = stmt.executeQuery()
-    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
+
+    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, updatedSchema, inputMetrics)
 
     CompletionIterator[InternalRow, Iterator[InternalRow]](
       new InterruptibleIterator(context, rowsIterator), close())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 2f1ee0f23d45a..97a36de9f8b08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -249,6 +249,7 @@ private[sql] case 
+ class JDBCRelation(
     jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
   extends BaseRelation
   with PrunedFilteredScan
+  with PrunedFilteredAggregateScan
   with InsertableRelation {
 
   override def sqlContext: SQLContext = sparkSession.sqlContext
@@ -275,6 +276,21 @@ private[sql] case class JDBCRelation(
       jdbcOptions).asInstanceOf[RDD[Row]]
   }
 
+  override def buildScan(
+      requiredColumns: Array[String],
+      filters: Array[Filter],
+      aggregation: Aggregation): RDD[Row] = {
+    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
+    JDBCRDD.scanTable(
+      sparkSession.sparkContext,
+      schema,
+      requiredColumns,
+      filters,
+      parts,
+      jdbcOptions,
+      aggregation).asInstanceOf[RDD[Row]]
+  }
+
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
     data.write
       .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 120fa5288dda9..2104a27bb388c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -73,7 +73,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
 
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(project, filters,
-        relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
+        relation @ DataSourceV2ScanRelation(_,
+          V1ScanWrapper(scan, translated, pushed, aggregation), output)) =>
       val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
       if (v1Relation.schema != scan.readSchema()) {
         throw new IllegalArgumentException(
@@ -88,6 +89,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         output.toStructType,
         translated.toSet,
         pushed.toSet,
+        aggregation,
         unsafeRowRDD,
         v1Relation,
         tableIdentifier = None)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 2208e930f6b08..2806f397ecda8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.{AggregateFunc, Aggregation}
 import org.apache.spark.sql.types.StructType
 
 object PushDownUtils extends PredicateHelper {
@@ -70,6 +72,41 @@ object PushDownUtils extends PredicateHelper {
     }
   }
 
+  /**
+   * Pushes down aggregates to the data source reader.
+   *
+   * The aggregation is handed to the reader only when every aggregate expression can be
+   * translated and every group-by expression is a plain column reference.
+   *
+   * @return the aggregation reported as pushed by the reader, or an empty Aggregation when
+   *         the reader does not support aggregate push down or reports nothing.
+   */
+  def pushAggregates(
+      scanBuilder: ScanBuilder,
+      aggregates: Seq[AggregateExpression],
+      groupby: Seq[Expression]): Aggregation = {
+    val noAggregation = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String])
+    scanBuilder match {
+      case r: SupportsPushDownAggregates =>
+        // None marks an expression that cannot be expressed as a data source aggregate.
+        val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate)
+        val groupByCols = groupby.map {
+          case AttributeReference(name, _, _, _) if name.nonEmpty => Some(name)
+          case _ => None
+        }
+
+        if (translatedAggregates.forall(_.isDefined) && groupByCols.forall(_.isDefined)) {
+          r.pushAggregation(Aggregation(translatedAggregates.flatten, groupByCols.flatten))
+        }
+
+        // `pushedAggregation` is implemented by the data source and may return null when
+        // nothing was pushed.
+        Option(r.pushedAggregation).getOrElse(noAggregation)
+
+      case _ => noAggregation
+    }
+  }
+
   /**
    * Applies column pruning to the data source, w.r.t. the references of the given expressions.
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index d2180566790ac..5eefbdb363913 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -17,38 +17,129 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression} +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.read.{Scan, V1Scan} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.{AggregateFunc, Aggregation} import org.apache.spark.sql.types.StructType -object V2ScanRelationPushDown extends Rule[LogicalPlan] { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper { + import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - case ScanOperation(project, filters, relation: DataSourceV2Relation) => - val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + case Aggregate(groupingExpressions, resultExpressions, child) => + child match { + case ScanOperation(project, filters, 
relation: DataSourceV2Relation) => + val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + val (pushedFilters, postScanFilters) = pushDownFilter(scanBuilder, filters, relation) + if (postScanFilters.nonEmpty) { + Aggregate(groupingExpressions, resultExpressions, child) + } else { + val aliasMap = getAliasMap(project) + var aggregates = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => + replaceAlias(agg, aliasMap).asInstanceOf[AggregateExpression] + } + }.distinct + aggregates = DataSourceStrategy.normalizeExprs(aggregates, relation.output) + .asInstanceOf[Seq[AggregateExpression]] - val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) - val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = - normalizedFilters.partition(SubqueryExpression.hasSubquery) + val groupingExpressionsWithoutAlias = groupingExpressions.flatMap{ expr => + expr.collect { + case a: AttributeReference => replaceAlias(a, aliasMap) + } + }.distinct + val normalizedgroupingExpressions = + DataSourceStrategy.normalizeExprs(groupingExpressionsWithoutAlias, relation.output) - // `pushedFilters` will be pushed down and evaluated in the underlying data sources. - // `postScanFilters` need to be evaluated after the scan. - // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
- val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( - scanBuilder, normalizedFiltersWithoutSubquery) - val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + val aggregation = PushDownUtils.pushAggregates(scanBuilder, aggregates, + normalizedgroupingExpressions) + + val (scan, output, normalizedProjects) = + processFilerAndColumn(scanBuilder, project, postScanFilters, relation) + + logInfo( + s""" + |Pushing operators to ${relation.name} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Pushed Aggregate Functions: ${aggregation.aggregateExpressions.mkString(", ")} + |Pushed Groupby: ${aggregation.groupByExpressions.mkString(", ")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val wrappedScan = scan match { + case v1: V1Scan => + val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true)) + V1ScanWrapper(v1, translated, pushedFilters, aggregation) + case _ => scan + } + if (aggregation.aggregateExpressions.isEmpty) { + val plan = buildLogicalPlan(project, relation, wrappedScan, output, + normalizedProjects, postScanFilters) + Aggregate(groupingExpressions, resultExpressions, plan) + } else { + val aggOutputBuilder = ArrayBuilder.make[AttributeReference] + for (i <- 0 until aggregates.length) { + aggOutputBuilder += AttributeReference( + aggregation.aggregateExpressions(i).toString, aggregates(i).dataType)() + } + val aggOutput = aggOutputBuilder.result + + val newOutputBuilder = ArrayBuilder.make[AttributeReference] + for (col <- aggOutput) { + newOutputBuilder += col + } + for (groupBy <- groupingExpressions) { + newOutputBuilder += groupBy.asInstanceOf[AttributeReference] + } + val newOutput = newOutputBuilder.result + + val r = buildLogicalPlan(newOutput, relation, wrappedScan, newOutput, + normalizedProjects, postScanFilters) + val plan = Aggregate(groupingExpressions, resultExpressions, r) + + var i = 0 + 
plan.transformExpressions { + case agg: AggregateExpression => + i += 1 + val aggFunction: aggregate.AggregateFunction = { + if (agg.aggregateFunction.isInstanceOf[aggregate.Max]) { + aggregate.Max(aggOutput(i - 1)) + } else if (agg.aggregateFunction.isInstanceOf[aggregate.Min]) { + aggregate.Min(aggOutput(i - 1)) + } else if (agg.aggregateFunction.isInstanceOf[aggregate.Average]) { + aggregate.Average(aggOutput(i - 1)) + } else if (agg.aggregateFunction.isInstanceOf[aggregate.Sum]) { + aggregate.Sum(aggOutput(i - 1)) + } else { + agg.aggregateFunction + } + } + // Aggregate filter is pushed to datasource + agg.copy(aggregateFunction = aggFunction, filter = None) + } + } + } + + case _ => + Aggregate(groupingExpressions, resultExpressions, child) + } + case ScanOperation(project, filters, relation: DataSourceV2Relation) => + val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) + val (pushedFilters, postScanFilters) = pushDownFilter (scanBuilder, filters, relation) + val (scan, output, normalizedProjects) = + processFilerAndColumn(scanBuilder, project, postScanFilters, relation) - val normalizedProjects = DataSourceStrategy - .normalizeExprs(project, relation.output) - .asInstanceOf[Seq[NamedExpression]] - val (scan, output) = PushDownUtils.pruneColumns( - scanBuilder, relation, normalizedProjects, postScanFilters) logInfo( s""" |Pushing operators to ${relation.name} @@ -60,31 +151,72 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { val wrappedScan = scan match { case v1: V1Scan => val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true)) - V1ScanWrapper(v1, translated, pushedFilters) + V1ScanWrapper(v1, translated, pushedFilters, + Aggregation(Seq.empty[AggregateFunc], Seq.empty[String])) + case _ => scan } - val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + buildLogicalPlan(project, relation, wrappedScan, output, normalizedProjects, postScanFilters) + } - val projectionOverSchema = 
ProjectionOverSchema(output.toStructType) - val projectionFunc = (expr: Expression) => expr transformDown { - case projectionOverSchema(newExpr) => newExpr - } + private def pushDownFilter( + scanBuilder: ScanBuilder, + filters: Seq[Expression], + relation: DataSourceV2Relation): (Seq[sources.Filter], Seq[Expression]) = { + val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output) + val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) = + normalizedFilters.partition(SubqueryExpression.hasSubquery) - val filterCondition = postScanFilters.reduceLeftOption(And) - val newFilterCondition = filterCondition.map(projectionFunc) - val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) - - val withProjection = if (withFilter.output != project) { - val newProjects = normalizedProjects - .map(projectionFunc) - .asInstanceOf[Seq[NamedExpression]] - Project(newProjects, withFilter) - } else { - withFilter - } + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
+ val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( + scanBuilder, normalizedFiltersWithoutSubquery) + val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery + (pushedFilters, postScanFilters) + } + + private def processFilerAndColumn( + scanBuilder: ScanBuilder, + project: Seq[NamedExpression], + postScanFilters: Seq[Expression], + relation: DataSourceV2Relation): + (Scan, Seq[AttributeReference], Seq[NamedExpression]) = { + val normalizedProjects = DataSourceStrategy + .normalizeExprs(project, relation.output) + .asInstanceOf[Seq[NamedExpression]] + val (scan, output) = PushDownUtils.pruneColumns( + scanBuilder, relation, normalizedProjects, postScanFilters) + (scan, output, normalizedProjects) + } - withProjection + private def buildLogicalPlan( + project: Seq[NamedExpression], + relation: DataSourceV2Relation, + wrappedScan: Scan, + output: Seq[AttributeReference], + normalizedProjects: Seq[NamedExpression], + postScanFilters: Seq[Expression]): LogicalPlan = { + val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output) + val projectionOverSchema = ProjectionOverSchema(output.toStructType) + val projectionFunc = (expr: Expression) => expr transformDown { + case projectionOverSchema(newExpr) => newExpr + } + + val filterCondition = postScanFilters.reduceLeftOption(And) + val newFilterCondition = filterCondition.map(projectionFunc) + val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation) + + val withProjection = if (withFilter.output != project) { + val newProjects = normalizedProjects + .map(projectionFunc) + .asInstanceOf[Seq[NamedExpression]] + Project(newProjects, withFilter) + } else { + withFilter + } + withProjection } } @@ -93,6 +225,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { case class V1ScanWrapper( v1Scan: V1Scan, translatedFilters: Seq[sources.Filter], - handledFilters: Seq[sources.Filter]) extends Scan { + 
handledFilters: Seq[sources.Filter], + pushedAggregates: sources.Aggregation) extends Scan { override def readSchema(): StructType = v1Scan.readSchema() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index 860232ba84f39..d8c29aeb1921c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -20,13 +20,14 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation -import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan} +import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter, TableScan} import org.apache.spark.sql.types.StructType case class JDBCScan( relation: JDBCRelation, prunedSchema: StructType, - pushedFilters: Array[Filter]) extends V1Scan { + pushedFilters: Array[Filter], + pushedAggregation: Aggregation) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -36,14 +37,15 @@ case class JDBCScan( override def schema: StructType = prunedSchema override def needConversion: Boolean = relation.needConversion override def buildScan(): RDD[Row] = { - relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters) + relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters, pushedAggregation) } }.asInstanceOf[T] } override def description(): String = { super.description() + ", prunedSchema: " + seqToString(prunedSchema) + - ", PushedFilters: " + seqToString(pushedFilters) + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggegates: " + seqToString(pushedAggregation.aggregateExpressions) } private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 270c5b6d92e32..26bd0c613f071 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -17,23 +17,26 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} import org.apache.spark.sql.jdbc.JdbcDialects -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.{AggregateFunc, Aggregation, Filter} import org.apache.spark.sql.types.StructType case class JDBCScanBuilder( session: SparkSession, schema: StructType, jdbcOptions: JDBCOptions) - extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns { + extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns + with SupportsPushDownAggregates { private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis private var pushedFilter = Array.empty[Filter] + private var pushedAggregations = Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]) + private var prunedSchema = schema override def pushFilters(filters: Array[Filter]): Array[Filter] = { @@ -49,6 +52,17 @@ case class JDBCScanBuilder( override def pushedFilters(): Array[Filter] = pushedFilter + override def pushAggregation(aggregation: 
Aggregation): Unit = { + if (jdbcOptions.pushDownAggregate) { + val dialect = JdbcDialects.get(jdbcOptions.url) + if (!JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect).isEmpty) { + pushedAggregations = aggregation + } + } + } + + override def pushedAggregation(): Aggregation = pushedAggregations + override def pruneColumns(requiredSchema: StructType): Unit = { // JDBC doesn't support nested column pruning. // TODO (SPARK-32593): JDBC support nested column and nested column pruning. @@ -65,6 +79,7 @@ case class JDBCScanBuilder( val resolver = session.sessionState.conf.resolver val timeZoneId = session.sessionState.conf.sessionLocalTimeZone val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions) - JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter) + JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), + prunedSchema, pushedFilter, pushedAggregation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 63e57c6804e16..ef17792ea877c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -273,6 +273,17 @@ trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] } +/** + * A scan whose `buildScan` takes the required columns, filters and aggregation to apply. + * @since 3.1.0 + */ +trait PrunedFilteredAggregateScan { + def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + aggregation: Aggregation): RDD[Row] +} + /** * A BaseRelation that can be used to insert data into it through the insert method. 
* If overwrite in insert method is true, the old data in the relation should be overwritten with diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index e8157e552d754..ec0ef9ef94e84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.{avg, lit, sum, udf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -64,6 +64,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { .executeUpdate() conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate() conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate() + conn.prepareStatement( + "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + + " bonus NUMERIC(6, 2))").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)") + .executeUpdate() } } @@ -109,6 +122,188 @@ class JDBCV2Suite extends 
QueryTest with SharedSparkSession { checkAnswer(df, Row("mary")) } + test("aggregate pushdown with alias") { + val df1 = spark.table("h2.test.employee") + var query1 = df1.select($"DEPT", $"SALARY".as("value")) + .groupBy($"DEPT") + .agg(sum($"value").as("total")) + .filter($"total" > 1000) + // query1.explain(true) + checkAnswer(query1, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000))) + val decrease = udf { (x: Double, y: Double) => x - y} + var query2 = df1.select($"DEPT", decrease($"SALARY", $"BONUS").as("value"), $"SALARY", $"BONUS") + .groupBy($"DEPT") + .agg(sum($"value"), sum($"SALARY"), sum($"BONUS")) + // query2.explain(true) + checkAnswer(query2, + Seq(Row(1, 16800.00, 19000.00, 2200.00), Row(2, 19500.00, 22000.00, 2500.00), + Row(6, 10800, 12000, 1200))) + + val cols = Seq("a", "b", "c", "d") + val df2 = sql("select * from h2.test.employee").toDF(cols: _*) + val df3 = df2.groupBy().sum("c") + // df3.explain(true) + checkAnswer(df3, Seq(Row(53000.00))) + + val df4 = df2.groupBy($"a").sum("c") + checkAnswer(df4, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000))) + } + + test("scan with aggregate push-down") { + val df1 = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + + " group by DEPT") + // df1.explain(true) + // scalastyle:off line.size.limit + // == Parsed Logical Plan == + // 'Aggregate ['DEPT], [unresolvedalias('MAX('SALARY), None), unresolvedalias('MIN('BONUS), None)] + // +- 'Filter ('dept > 0) + // +- 'UnresolvedRelation [h2, test, employee], [] + // + // == Analyzed Logical Plan == + // max(SALARY): int, min(BONUS): int + // Aggregate [DEPT#0], [max(SALARY#2) AS max(SALARY)#6, min(BONUS#3) AS min(BONUS)#7] + // +- Filter (dept#0 > 0) + // +- SubqueryAlias h2.test.employee + // +- RelationV2[DEPT#0, NAME#1, SALARY#2, BONUS#3] test.employee + // + // == Optimized Logical Plan == + // Aggregate [DEPT#0], [max(max(SALARY)#13) AS max(SALARY)#6, min(min(BONUS)#14) AS min(BONUS)#7] + // +- RelationV2[DEPT#0, 
max(SALARY)#13, min(BONUS)#14] test.employee + // + // == Physical Plan == + // *(2) HashAggregate(keys=[DEPT#0], functions=[max(max(SALARY)#13), min(min(BONUS)#14)], output=[max(SALARY)#6, min(BONUS)#7]) + // +- Exchange hashpartitioning(DEPT#0, 5), true, [id=#10] + // +- *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@3d9f0a5 [DEPT#0,max(SALARY)#13,min(BONUS)#14] PushedAggregates: [*Max(SALARY,false,None), *Min(BONUS,false,None)], PushedFilters: [IsNotNull(dept), GreaterThan(dept,0)], PushedGroupby: [*DEPT], ReadSchema: struct// scalastyle:on line.size.limit + // + // df1.show + // +-----------+----------+ + // |max(SALARY)|min(BONUS)| + // +-----------+----------+ + // | 10000| 1000| + // | 12000| 1200| + // | 12000| 1200| + // +-----------+----------+ + checkAnswer(df1, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200))) + + val df2 = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0") + // df2.explain(true) + // scalastyle:off line.size.limit + // == Parsed Logical Plan == + // 'Project [unresolvedalias('MAX('ID), None), unresolvedalias('MIN('ID), None)] + // +- 'Filter ('id > 0) + // +- 'UnresolvedRelation [h2, test, people], [] + // + // == Analyzed Logical Plan == + // max(ID): int, min(ID): int + // Aggregate [max(ID#29) AS max(ID)#32, min(ID#29) AS min(ID)#33] + // +- Filter (id#29 > 0) + // +- SubqueryAlias h2.test.people + // +- RelationV2[NAME#28, ID#29] test.people + // + // == Optimized Logical Plan == + // Aggregate [max(max(ID)#37) AS max(ID)#32, min(min(ID)#38) AS min(ID)#33] + // +- RelationV2[max(ID)#37, min(ID)#38] test.people + // + // == Physical Plan == + // *(2) HashAggregate(keys=[], functions=[max(max(ID)#37), min(min(ID)#38)], output=[max(ID)#32, min(ID)#33]) + // +- Exchange SinglePartition, true, [id=#44] + // +- *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@5ed31735 [max(ID)#37,min(ID)#38] PushedAggregates: [*Max(ID,false,None), *Min(ID,false,None)], 
PushedFilters: [IsNotNull(id), GreaterThan(id,0)], PushedGroupby: [], ReadSchema: struct + // scalastyle:on line.size.limit + + // df2.show() + // +-------+-------+ + // |max(ID)|min(ID)| + // +-------+-------+ + // | 2| 1| + // +-------+-------+ + checkAnswer(df2, Seq(Row(2, 1))) + + val df3 = sql("select AVG(ID) FROM h2.test.people where id > 0") + checkAnswer(df3, Seq(Row(1.0))) + + val df4 = sql("select MAX(SALARY) + 1 FROM h2.test.employee") + // df4.explain(true) + // scalastyle:off line.size.limit + // == Parsed Logical Plan == + // 'Project [unresolvedalias(('MAX('SALARY) + 1), None)] + // +- 'UnresolvedRelation [h2, test, employee], [] + // + // == Analyzed Logical Plan == + // (max(SALARY) + 1): int + // Aggregate [(max(SALARY#68) + 1) AS (max(SALARY) + 1)#71] + // +- SubqueryAlias h2.test.employee + // +- RelationV2[DEPT#66, NAME#67, SALARY#68, BONUS#69] test.employee + // + // == Optimized Logical Plan == + // Aggregate [(max((max(SALARY) + 1)#74) + 1) AS (max(SALARY) + 1)#71] + // +- RelationV2[(max(SALARY) + 1)#74] test.employee + // + // == Physical Plan == + // *(2) HashAggregate(keys=[], functions=[max((max(SALARY) + 1)#74)], output=[(max(SALARY) + 1)#71]) + // +- Exchange SinglePartition, true, [id=#112] + // +- *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@20864cd1 [(max(SALARY) + 1)#74] PushedAggregates: [*Max(SALARY,false,None)], PushedFilters: [], PushedGroupby: [], ReadSchema: struct<(max(SALARY) + 1):int> + // scalastyle:on line.size.limit + checkAnswer(df4, Seq(Row(12001))) + + // COUNT push down is not supported yet + val df5 = sql("select COUNT(*) FROM h2.test.employee") + // df5.explain(true) + // scalastyle:off line.size.limit + // == Parsed Logical Plan == + // 'Project [unresolvedalias('COUNT(1), None)] + // +- 'UnresolvedRelation [h2, test, employee], [] + // + // == Analyzed Logical Plan == + // count(1): bigint + // Aggregate [count(1) AS count(1)#87L] + // +- SubqueryAlias h2.test.employee + // +- 
RelationV2[DEPT#82, NAME#83, SALARY#84, BONUS#85] test.employee + // + // == Optimized Logical Plan == + // Aggregate [count(1) AS count(1)#87L] + // +- RelationV2[] test.employee + // + // == Physical Plan == + // *(2) HashAggregate(keys=[], functions=[count(1)], output=[count(1)#87L]) + // *(2) HashAggregate(keys=[], functions=[count(1)], output=[count(1)#87L]) + // +- Exchange SinglePartition, true, [id=#149] + // +- *(1) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#90L]) + // +- *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$1@63262071 [] PushedAggregates: [], PushedFilters: [], PushedGroupby: [], ReadSchema: struct<> + // scalastyle:on line.size.limit + checkAnswer(df5, Seq(Row(5))) + + val df6 = sql("select MIN(SALARY), MIN(BONUS), MIN(SALARY) * MIN(BONUS) FROM h2.test.employee") + // df6.explain(true) + checkAnswer(df6, Seq(Row(9000, 1000, 9000000))) + + val df7 = sql("select MIN(salary), MIN(bonus), SUM(SALARY * BONUS) FROM h2.test.employee") + // df7.explain(true) + checkAnswer(df7, Seq(Row(9000, 1000, 62600000))) + + val df8 = sql("select BONUS, SUM(SALARY+BONUS), SALARY FROM h2.test.employee" + + " GROUP BY SALARY, BONUS") + // df8.explain(true) + checkAnswer(df8, Seq(Row(1000, 11000, 10000), Row(1200, 26400, 12000), + Row(1200, 10200, 9000), Row(1300, 11300, 10000))) + + val df9 = spark.table("h2.test.employee") + val sub2 = udf { (x: String) => x.substring(0, 3) } + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val df10 = df9.select($"SALARY", $"BONUS", sub2($"NAME").as("nsub2")) + .filter("SALARY > 100") + .filter(name($"nsub2")) + .agg(avg($"SALARY").as("avg_salary")) + // df10.explain(true) + checkAnswer(df10, Seq(Row(9666.666667))) + } + + test("scan with aggregate distinct push-down") { + checkAnswer(sql("SELECT SUM(SALARY) FROM h2.test.employee"), Seq(Row(53000))) + checkAnswer(sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee"), Seq(Row(31000))) + checkAnswer(sql("SELECT 
AVG(DEPT) FROM h2.test.employee"), Seq(Row(2))) + checkAnswer(sql("SELECT AVG(DISTINCT DEPT) FROM h2.test.employee"), Seq(Row(3))) + } + test("read/write with partition info") { withTable("h2.test.abc") { sql("CREATE TABLE h2.test.abc AS SELECT * FROM h2.test.people") @@ -145,7 +340,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession { test("show tables") { checkAnswer(sql("SHOW TABLES IN h2.test"), - Seq(Row("test", "people"), Row("test", "empty_table"))) + Seq(Row("test", "people"), Row("test", "empty_table"), Row("test", "employee"))) } test("SQL API: create table as select") {