diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 8951c37e127b6..f1e0e6d80c561 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -26,12 +26,12 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder} -import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum, UserDefinedAggregateFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType} +import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType} import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { @@ -44,6 +44,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit pushDownFilters, pushDownAggregates, pushDownLimitAndOffset, + buildScanWithPushedAggregate, pruneColumns) pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) => @@ -92,189 +93,201 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform { // update the scan builder with agg pushdown and return a new plan with agg pushed - case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => - child match { - case ScanOperation(project, filters, sHolder: ScanBuilderHolder) - if filters.isEmpty && CollapseProject.canCollapseExpressions( - resultExpressions, project, alwaysInline = true) => - sHolder.builder match { - case r: SupportsPushDownAggregates => - val aliasMap = getAliasMap(project) - val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap)) - val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap)) - - val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) - val normalizedAggregates = DataSourceStrategy.normalizeExprs( - aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] - val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( - actualGroupExprs, sHolder.relation.output) - val translatedAggregates = DataSourceStrategy.translateAggregation( - normalizedAggregates, normalizedGroupingExpressions) - val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { - if (translatedAggregates.isEmpty || - r.supportCompletePushDown(translatedAggregates.get) || - 
translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { - (actualResultExprs, aggregates, translatedAggregates) - } else { - // scalastyle:off - // The data source doesn't support the complete push-down of this aggregation. - // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be - // pushed, completely or partially. - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT avg(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] - // +- ScanOperation[...] - // - // After convert avg(c1#9) to sum(c1#9)/count(c1#9) - // we have the following - // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] - // +- ScanOperation[...] - // scalastyle:on - val newResultExpressions = actualResultExprs.map { expr => - expr.transform { - case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => - val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) - val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) - avg.evaluateExpression transform { - case a: Attribute if a.semanticEquals(avg.sum) => - addCastIfNeeded(sum, avg.sum.dataType) - case a: Attribute if a.semanticEquals(avg.count) => - addCastIfNeeded(count, avg.count.dataType) - } - } - }.asInstanceOf[Seq[NamedExpression]] - // Because aggregate expressions changed, translate them again. - aggExprToOutputOrdinal.clear() - val newAggregates = - collectAggregates(newResultExpressions, aggExprToOutputOrdinal) - val newNormalizedAggregates = DataSourceStrategy.normalizeExprs( - newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] - (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation( - newNormalizedAggregates, normalizedGroupingExpressions)) + case agg: Aggregate => rewriteAggregate(agg) + } + + private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match { + case ScanOperation(project, Nil, holder @ ScanBuilderHolder(_, _, + r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions( + agg.aggregateExpressions, project, alwaysInline = true) => + val aliasMap = getAliasMap(project) + val actualResultExprs = agg.aggregateExpressions.map(replaceAliasButKeepName(_, aliasMap)) + val actualGroupExprs = agg.groupingExpressions.map(replaceAlias(_, aliasMap)) + + val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) + val normalizedAggExprs = DataSourceStrategy.normalizeExprs( + aggregates, holder.relation.output).asInstanceOf[Seq[AggregateExpression]] + val normalizedGroupingExpr = DataSourceStrategy.normalizeExprs( + actualGroupExprs, holder.relation.output) + val translatedAggOpt = DataSourceStrategy.translateAggregation( + normalizedAggExprs, normalizedGroupingExpr) + if (translatedAggOpt.isEmpty) { + // Cannot translate the catalyst aggregate, return the query plan unchanged. + return agg + } + + val (finalResultExprs, finalAggExprs, translatedAgg, canCompletePushDown) = { + if (r.supportCompletePushDown(translatedAggOpt.get)) { + (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, true) + } else if (!translatedAggOpt.get.aggregateExpressions().exists(_.isInstanceOf[Avg])) { + (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, false) + } else { + // scalastyle:off + // The data source doesn't support the complete push-down of this aggregation. 
+ // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be + // pushed, completely or partially. + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT avg(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // + // After convert avg(c1#9) to sum(c1#9)/count(c1#9) + // we have the following + // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // scalastyle:on + val newResultExpressions = actualResultExprs.map { expr => + expr.transform { + case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => + val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) + val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) + avg.evaluateExpression transform { + case a: Attribute if a.semanticEquals(avg.sum) => + addCastIfNeeded(sum, avg.sum.dataType) + case a: Attribute if a.semanticEquals(avg.count) => + addCastIfNeeded(count, avg.count.dataType) } - } + } + }.asInstanceOf[Seq[NamedExpression]] + // Because aggregate expressions changed, translate them again. + aggExprToOutputOrdinal.clear() + val newAggregates = + collectAggregates(newResultExpressions, aggExprToOutputOrdinal) + val newNormalizedAggExprs = DataSourceStrategy.normalizeExprs( + newAggregates, holder.relation.output).asInstanceOf[Seq[AggregateExpression]] + val newTranslatedAggOpt = DataSourceStrategy.translateAggregation( + newNormalizedAggExprs, normalizedGroupingExpr) + if (newTranslatedAggOpt.isEmpty) { + // Ideally we should never reach here. But if we end up with not able to translate + // new aggregate with AVG replaced by SUM/COUNT, revert to the original one. + (actualResultExprs, normalizedAggExprs, translatedAggOpt.get, false) + } else { + (newResultExpressions, newNormalizedAggExprs, newTranslatedAggOpt.get, + r.supportCompletePushDown(newTranslatedAggOpt.get)) + } + } + } - if (finalTranslatedAggregates.isEmpty) { - aggNode // return original plan node - } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) && - !supportPartialAggPushDown(finalTranslatedAggregates.get)) { - aggNode // return original plan node - } else { - val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) - if (pushedAggregates.isEmpty) { - aggNode // return original plan node - } else { - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. - val scan = sHolder.builder.build() - - // scalastyle:off - // use the group by columns and aggregate columns as the output columns - // e.g. 
TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] - // scalastyle:on - val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + finalAggregates.length) - val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map { - case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId) - case ((expr, attr), ordinal) => - if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { - groupByExprToOutputOrdinal(expr.canonicalized) = ordinal - } - attr - } - val aggOutput = newOutput.drop(groupAttrs.length) - val output = groupAttrs ++ aggOutput - - logInfo( - s""" - |Pushing operators to ${sHolder.relation.name} - |Pushed Aggregate Functions: - | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} - |Pushed Group by: - | ${pushedAggregates.get.groupByExpressions.mkString(", ")} - |Output: ${output.mkString(", ")} - """.stripMargin) - - val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) - val scanRelation = - DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - if (r.supportCompletePushDown(pushedAggregates.get)) { - val projectExpressions = finalResultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val child = - addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) - Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) - case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => - val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) - addCastIfNeeded(groupAttrs(ordinal), expr.dataType) - } - }.asInstanceOf[Seq[NamedExpression]] - Project(projectExpressions, scanRelation) + if (!canCompletePushDown && !supportPartialAggPushDown(translatedAgg)) { + return agg + } + if (!r.pushAggregation(translatedAgg)) { + return agg + } + + // scalastyle:off + // We name the output columns of group expressions and aggregate functions by + // ordinal: `group_col_0`, `group_col_1`, ..., `agg_func_0`, `agg_func_1`, ... + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // Use group_col_0, agg_func_0, agg_func_1 as output for ScanBuilderHolder. + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [group_col_0#10], [min(agg_func_0#21) AS min(c1)#17, max(agg_func_1#22) AS max(c1)#18] + // +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22] + // Later, we build the `Scan` instance and convert ScanBuilderHolder to DataSourceV2ScanRelation. 
+ // scalastyle:on + val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i) => + AttributeReference(s"group_col_$i", e.dataType)() + } + val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) => + AttributeReference(s"agg_func_$i", e.dataType)() + } + val newOutput = groupOutput ++ aggOutput + val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) => + if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { + groupByExprToOutputOrdinal(expr.canonicalized) = ordinal + } + } + + holder.pushedAggregate = Some(translatedAgg) + holder.output = newOutput + logInfo( + s""" + |Pushing operators to ${holder.relation.name} + |Pushed Aggregate Functions: + | ${translatedAgg.aggregateExpressions().mkString(", ")} + |Pushed Group by: + | ${translatedAgg.groupByExpressions.mkString(", ")} + """.stripMargin) + + if (canCompletePushDown) { + val projectExpressions = finalResultExprs.map { expr => + expr.transformDown { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + Alias(aggOutput(ordinal), agg.resultAttribute.name)(agg.resultAttribute.exprId) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + expr match { + case ne: NamedExpression => Alias(groupOutput(ordinal), ne.name)(ne.exprId) + case _ => groupOutput(ordinal) + } + } + }.asInstanceOf[Seq[NamedExpression]] + Project(projectExpressions, holder) + } else { + // scalastyle:off + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... + // + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] + // we have the following + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // scalastyle:on + val aggExprs = finalResultExprs.map(_.transform { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val aggAttribute = aggOutput(ordinal) + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case max: aggregate.Max => + max.copy(child = aggAttribute) + case min: aggregate.Min => + min.copy(child = aggAttribute) + case sum: aggregate.Sum => + // To keep the dataType of `Sum` unchanged, we need to cast the + // data-source-aggregated result to `Sum.child.dataType` if it's decimal. + // See `SumBase.resultType` + val newChild = if (sum.dataType.isInstanceOf[DecimalType]) { + addCastIfNeeded(aggAttribute, sum.child.dataType) } else { - val plan = Aggregate(output.take(groupingExpressions.length), - finalResultExpressions, scanRelation) - - // scalastyle:off - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9, c2#10] ... 
- // - // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] - // we have the following - // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // scalastyle:on - plan.transformExpressions { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val aggAttribute = aggOutput(ordinal) - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case max: aggregate.Max => - max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType)) - case min: aggregate.Min => - min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType)) - case sum: aggregate.Sum => - sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType)) - case _: aggregate.Count => - aggregate.Sum(addCastIfNeeded(aggAttribute, LongType)) - case other => other - } - agg.copy(aggregateFunction = aggFunction) - case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => - val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) - addCastIfNeeded(groupAttrs(ordinal), expr.dataType) - } + aggAttribute } - } + sum.copy(child = newChild) + case _: aggregate.Count => + aggregate.Sum(aggAttribute) + case other => other } - case _ => aggNode - } - case _ => aggNode + agg.copy(aggregateFunction = aggFunction) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + expr match { + case ne: NamedExpression => Alias(groupOutput(ordinal), ne.name)(ne.exprId) + case _ => groupOutput(ordinal) + } + }).asInstanceOf[Seq[NamedExpression]] + Aggregate(groupOutput, aggExprs, holder) } + + case _ => agg } - private def collectAggregates(resultExpressions: Seq[NamedExpression], + private def collectAggregates( + resultExpressions: Seq[NamedExpression], aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { var ordinal = 0 resultExpressions.flatMap { expr => @@ -292,15 +305,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } private def supportPartialAggPushDown(agg: Aggregation): Boolean = { - // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. - // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down. - agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().exists { + // We can only partially push down min/max/sum/count without DISTINCT. + agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().forall { case sum: Sum => !sum.isDistinct case count: Count => !count.isDistinct - case avg: Avg => !avg.isDistinct - case _: GeneralAggregateFunc => false - case _: UserDefinedAggregateFunc => false - case _ => true + case _: Min | _: Max | _: CountStar => true + case _ => false } } @@ -311,6 +321,26 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit Cast(expression, expectedDataType) } + def buildScanWithPushedAggregate(plan: LogicalPlan): LogicalPlan = plan.transform { + case holder: ScanBuilderHolder if holder.pushedAggregate.isDefined => + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. 
All the other columns are not + // included in the output. + val scan = holder.builder.build() + val realOutput = scan.readSchema().toAttributes + assert(realOutput.length == holder.output.length, + "The data source returns unexpected number of columns") + val wrappedScan = getWrappedScan(scan, holder) + val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput) + val projectList = realOutput.zip(holder.output).map { case (a1, a2) => + // The data source may return columns with arbitrary data types and it's safer to cast them + // to the expected data type. + assert(Cast.canCast(a1.dataType, a2.dataType)) + Alias(addCastIfNeeded(a1, a2.dataType), a2.name)(a2.exprId) + } + Project(projectList, scanRelation) + } + def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => // column pruning @@ -325,7 +355,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit |Output: ${output.mkString(", ")} """.stripMargin) - val wrappedScan = getWrappedScan(scan, sHolder, Option.empty[Aggregation]) + val wrappedScan = getWrappedScan(scan, sHolder) val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) @@ -378,8 +408,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } (operation, isPushed && !isPartiallyPushed) case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder)) - if filter.isEmpty && CollapseProject.canCollapseExpressions( - order, project, alwaysInline = true) => + // Without building the Scan, we do not know the resulting column names after aggregate + // push-down, and thus can't push down Top-N which needs to know the ordering column names. + // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same + // columns, which we know the resulting column names: the original table columns. 
+ if sHolder.pushedAggregate.isEmpty && filter.isEmpty && + CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) => val aliasMap = getAliasMap(project) val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]] val normalizedOrders = DataSourceStrategy.normalizeExprs( @@ -480,10 +514,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } } - private def getWrappedScan( - scan: Scan, - sHolder: ScanBuilderHolder, - aggregation: Option[Aggregation]): Scan = { + private def getWrappedScan(scan: Scan, sHolder: ScanBuilderHolder): Scan = { scan match { case v1: V1Scan => val pushedFilters = sHolder.builder match { @@ -491,7 +522,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit f.pushedFilters() case _ => Array.empty[sources.Filter] } - val pushedDownOperators = PushedDownOperators(aggregation, sHolder.pushedSample, + val pushedDownOperators = PushedDownOperators(sHolder.pushedAggregate, sHolder.pushedSample, sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders, sHolder.pushedPredicates) V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan @@ -500,7 +531,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit } case class ScanBuilderHolder( - output: Seq[AttributeReference], + var output: Seq[AttributeReference], relation: DataSourceV2Relation, builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None @@ -512,6 +543,8 @@ case class ScanBuilderHolder( var pushedSample: Option[TableSampleInfo] = None var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate] + + var pushedAggregate: Option[Aggregation] = None } // A wrapper for v1 scan to carry the translated filters and the handled ones, along with diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index ddcf28652e91d..f0fcb27307d27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -265,9 +265,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .limit(1) - checkLimitRemoved(df4, false) + checkAggregateRemoved(df4) + checkLimitRemoved(df4) checkPushedInfo(df4, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedLimit: LIMIT 1") checkAnswer(df4, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -340,9 +344,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .offset(1) - checkOffsetRemoved(df5, false) + checkAggregateRemoved(df5) + checkLimitRemoved(df5) checkPushedInfo(df5, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1") checkAnswer(df5, Seq(Row(2, 22000.00), Row(6, 12000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -477,10 +485,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .groupBy("DEPT").sum("SALARY") .limit(2) .offset(1) - checkLimitRemoved(df10, false) - 
checkOffsetRemoved(df10, false) + checkAggregateRemoved(df10) + checkLimitRemoved(df10) + checkOffsetRemoved(df10) checkPushedInfo(df10, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedLimit: LIMIT 2", + "PushedOffset: OFFSET 1") checkAnswer(df10, Seq(Row(2, 22000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -612,10 +625,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df9, Seq(Row(2, "david", 10000.00, 1300.0, true))) val df10 = sql("SELECT dept, sum(salary) FROM h2.test.employee group by dept LIMIT 1 OFFSET 1") - checkLimitRemoved(df10, false) - checkOffsetRemoved(df10, false) + checkAggregateRemoved(df10) + checkLimitRemoved(df10) + checkOffsetRemoved(df10) checkPushedInfo(df10, - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByExpressions: [DEPT], ") + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedLimit: LIMIT 2", + "PushedOffset: OFFSET 1") checkAnswer(df10, Seq(Row(2, 22000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") }
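
Reviewer note (not part of the patch): below is a minimal connector-side sketch showing the SupportsPushDownAggregates hooks that V2ScanRelationPushDown invokes above. The class name and its push-down policy are made up for illustration; only the interface methods (supportCompletePushDown, pushAggregation, build) and the Aggregation/Min/Max/CountStar types come from the real DSv2 API. supportCompletePushDown decides whether Spark can drop its own Aggregate node entirely, while pushAggregation accepts the translated Aggregation for complete or partial push-down.

    import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, CountStar, Max, Min}
    import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}
    import org.apache.spark.sql.types.StructType

    // Hypothetical ScanBuilder used only to illustrate the call sites in the rule above.
    class ExampleAggPushDownScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
      private var pushedAggregation: Option[Aggregation] = None

      // Return true only when the source can compute the final aggregated result itself;
      // the rule then replaces Spark's Aggregate with a Project over the pushed scan output.
      override def supportCompletePushDown(aggregation: Aggregation): Boolean =
        aggregation.groupByExpressions().isEmpty &&
          aggregation.aggregateExpressions().forall(_.isInstanceOf[CountStar])

      // Accept the aggregation (possibly for partial push-down) or return false to keep
      // the original Aggregate node; this sketch only accepts MIN/MAX/COUNT(*).
      override def pushAggregation(aggregation: Aggregation): Boolean = {
        val supported = aggregation.aggregateExpressions().forall {
          case _: Min | _: Max | _: CountStar => true
          case _ => false
        }
        if (supported) pushedAggregation = Some(aggregation)
        supported
      }

      override def build(): Scan = new Scan {
        // A real implementation must report the group-by columns followed by the aggregate
        // columns, matching the ordinal-based group_col_i / agg_func_i output built by
        // buildScanWithPushedAggregate; left empty here because this is only a sketch.
        override def readSchema(): StructType = StructType(Nil)
      }
    }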