Aggregation.java
@@ -20,7 +20,7 @@
 import java.io.Serializable;

 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;

 /**
  * Aggregation in SQL statement.
@@ -30,14 +30,14 @@
 @Evolving
 public final class Aggregation implements Serializable {
   private final AggregateFunc[] aggregateExpressions;
-  private final NamedReference[] groupByColumns;
+  private final Expression[] groupByExpressions;

-  public Aggregation(AggregateFunc[] aggregateExpressions, NamedReference[] groupByColumns) {
+  public Aggregation(AggregateFunc[] aggregateExpressions, Expression[] groupByExpressions) {
     this.aggregateExpressions = aggregateExpressions;
-    this.groupByColumns = groupByColumns;
+    this.groupByExpressions = groupByExpressions;
   }

   public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }

-  public NamedReference[] groupByColumns() { return groupByColumns; }
+  public Expression[] groupByExpressions() { return groupByExpressions; }
 }
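
With this change, group-by keys are general connector Expressions instead of plain column references. A minimal sketch of building the new Aggregation from Scala; FieldReference and CountStar are existing v2 expression classes, the values are hypothetical:

  import org.apache.spark.sql.connector.expressions.{Expression, FieldReference}
  import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar}

  // COUNT(*) grouped by the single column "dept". A FieldReference is itself
  // a v2 Expression, so plain column group-bys still fit the new signature.
  val agg = new Aggregation(
    Array[AggregateFunc](new CountStar()),
    Array[Expression](FieldReference.column("dept")))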
DataSourceScanExec.scala
@@ -163,7 +163,7 @@ case class RowDataSourceScanExec(
       "PushedFilters" -> pushedFilters) ++
       pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
         Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())),
-          "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++
+          "PushedGroupByExpressions" -> seqToString(v.groupByExpressions.map(_.describe())))} ++
       topNOrLimitInfo ++
       pushedDownOperators.sample.map(v => "PushedSample" ->
         s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
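
For illustration only: the renamed key changes what EXPLAIN shows on a scan node with a pushed-down aggregate. For a hypothetical query grouping by DEPT, the scan metadata would read roughly like (exact rendering comes from seqToString above):

  PushedAggregates: [COUNT(*)]
  PushedGroupByExpressions: [DEPT]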
AggregatePushDownUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.execution.RowToColumnConverter
 import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
@@ -93,8 +94,8 @@ object AggregatePushDownUtils {
       return None
     }

-    if (aggregation.groupByColumns.nonEmpty &&
-      partitionNames.size != aggregation.groupByColumns.length) {
+    if (aggregation.groupByExpressions.nonEmpty &&
+      partitionNames.size != aggregation.groupByExpressions.length) {
       // If there are group by columns, we only push down if the group by columns are the same as
       // the partition columns. In theory, if group by columns are a subset of partition columns,
       // we should still be able to push down. e.g. if table t has partition columns p1, p2, and p3,
@@ -106,11 +107,11 @@
       // aggregate push down simple and don't handle this complicated case for now.
       return None
     }
-    aggregation.groupByColumns.foreach { col =>
+    aggregation.groupByExpressions.map(extractColName).foreach { colName =>
       // don't push down if the group by columns are not the same as the partition columns (order
       // doesn't matter because reorder can be done at data source layer)
-      if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None
-      finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head))
+      if (colName.isEmpty || !isPartitionCol(colName.get)) return None
+      finalSchema = finalSchema.add(getStructFieldForCol(colName.get))
     }

     aggregation.aggregateExpressions.foreach {
@@ -137,7 +138,8 @@
   def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = {
     a.aggregateExpressions.sortBy(_.hashCode())
       .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
-      a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
+      a.groupByExpressions.sortBy(_.hashCode())
+        .sameElements(b.groupByExpressions.sortBy(_.hashCode()))
   }

   /**
@@ -164,7 +166,7 @@
   def getSchemaWithoutGroupingExpression(
       aggSchema: StructType,
       aggregation: Aggregation): StructType = {
-    val numOfGroupByColumns = aggregation.groupByColumns.length
+    val numOfGroupByColumns = aggregation.groupByExpressions.length
     if (numOfGroupByColumns > 0) {
       new StructType(aggSchema.fields.drop(numOfGroupByColumns))
     } else {
@@ -179,7 +181,7 @@
       partitionSchema: StructType,
       aggregation: Aggregation,
       partitionValues: InternalRow): InternalRow = {
-    val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head)
+    val groupByColNames = aggregation.groupByExpressions.flatMap(extractColName)
     assert(groupByColNames.length == partitionSchema.length &&
       groupByColNames.length == partitionValues.numFields, "The number of group by columns " +
       s"${groupByColNames.length} should be the same as partition schema length " +
@@ -197,4 +199,9 @@
       partitionValues
     }
   }
+
+  private def extractColName(v2Expr: V2Expression): Option[String] = v2Expr match {
+    case f: FieldReference if f.fieldNames.length == 1 => Some(f.fieldNames.head)
+    case _ => None
+  }
 }
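
The new extractColName helper accepts only top-level column references, which is what keeps the partition-column check above sound. An illustration with hypothetical references (FieldReference is the public case class imported above):

  import org.apache.spark.sql.connector.expressions.FieldReference

  // A single-part reference yields a column name; anything else yields None
  // and blocks the push-down in the foreach loop above.
  extractColName(FieldReference(Seq("p1")))      // => Some("p1")
  extractColName(FieldReference(Seq("a", "b")))  // => None (nested field a.b)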
DataSourceStrategy.scala
@@ -759,14 +759,13 @@ object DataSourceStrategy
   protected[sql] def translateAggregation(
       aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {

-    def columnAsString(e: Expression): Option[FieldReference] = e match {
-      case PushableColumnWithoutNestedColumn(name) =>
-        Some(FieldReference.column(name).asInstanceOf[FieldReference])
+    def translateGroupBy(e: Expression): Option[V2Expression] = e match {
+      case PushableExpression(expr) => Some(expr)
       case _ => None
     }

     val translatedAggregates = aggregates.flatMap(translateAggregate)
-    val translatedGroupBys = groupBy.flatMap(columnAsString)
+    val translatedGroupBys = groupBy.flatMap(translateGroupBy)

     if (translatedAggregates.length != aggregates.length ||
       translatedGroupBys.length != groupBy.length) {

Inline review on the new PushableExpression case:
  Reviewer: To confirm: this supports nested columns, right?
  Author: PushableExpression uses PushableColumnWithoutNestedColumn by default.
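
A simplified stand-in for what the review thread describes: the default PushableExpression path matches only bare column references, so a nested field such as a GetStructField makes translateGroupBy return None and the whole aggregation stays un-pushed. A sketch under that assumption, using Catalyst types:

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GetStructField}
  import org.apache.spark.sql.connector.expressions.{FieldReference, Expression => V2Expression}

  // Simplified: only a bare column reference translates to a v2 expression.
  def translateGroupBySketch(e: Expression): Option[V2Expression] = e match {
    case a: AttributeReference => Some(FieldReference.column(a.name))
    case _: GetStructField => None // nested field: not pushable by default
    case _ => None
  }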
OrcUtils.scala
@@ -516,7 +516,7 @@ object OrcUtils extends Logging {
       val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupBy,
         (0 until schemaWithoutGroupBy.length).toArray)
       val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues)
-      if (aggregation.groupByColumns.nonEmpty) {
+      if (aggregation.groupByExpressions.nonEmpty) {
         val reOrderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
           partitionSchema, aggregation, partitionValues)
         new JoinedRow(reOrderedPartitionValues, resultRow)
ParquetUtils.scala
@@ -277,7 +277,7 @@ object ParquetUtils {
         throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
     }

-    if (aggregation.groupByColumns.nonEmpty) {
+    if (aggregation.groupByExpressions.nonEmpty) {
       val reorderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
         partitionSchema, aggregation, partitionValues)
       new JoinedRow(reorderedPartitionValues, converter.currentRecord)
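
In both the ORC and Parquet paths the grouping values never come from the data files: they are the partition values, reordered to match the group-by order, then prepended to the aggregate results. A minimal sketch of that composition with hypothetical rows (JoinedRow is Catalyst-internal):

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.JoinedRow

  val partitionValues = InternalRow(1)   // hypothetical: grouping column p1 = 1
  val aggResults = InternalRow(42L, 7L)  // hypothetical: MAX(c), MIN(c) from file metadata
  val outputRow = new JoinedRow(partitionValues, aggResults) // reads as (1, 42, 7)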
V2ScanRelationPushDown.scala
@@ -184,9 +184,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
         // scalastyle:on
         val newOutput = scan.readSchema().toAttributes
         assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
-        val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
-          case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
-          case (_, b) => b
+        val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+        val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map {
+          case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId)
+          case ((expr, attr), ordinal) =>
+            if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) {
+              groupByExprToOutputOrdinal(expr.canonicalized) = ordinal
+            }
+            attr
         }
         val aggOutput = newOutput.drop(groupAttrs.length)
         val output = groupAttrs ++ aggOutput
@@ -197,7 +202,7 @@
            |Pushed Aggregate Functions:
            | ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
            |Pushed Group by:
-           | ${pushedAggregates.get.groupByColumns.mkString(", ")}
+           | ${pushedAggregates.get.groupByExpressions.mkString(", ")}
            |Output: ${output.mkString(", ")}
          """.stripMargin)

@@ -206,14 +211,15 @@
           DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
         if (r.supportCompletePushDown(pushedAggregates.get)) {
           val projectExpressions = finalResultExpressions.map { expr =>
-            // TODO At present, only push down group by attribute is supported.
-            // In future, more attribute conversion is extended here. e.g. GetStructField
-            expr.transform {
+            expr.transformDown {
               case agg: AggregateExpression =>
                 val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
                 val child =
                   addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
                 Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
+              case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+                val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+                addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
             }
           }.asInstanceOf[Seq[NamedExpression]]
           Project(projectExpressions, scanRelation)
@@ -256,6 +262,9 @@
               case other => other
             }
             agg.copy(aggregateFunction = aggFunction)
+          case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+            val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+            addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
         }
       }
     }
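
The new groupByExprToOutputOrdinal map is what lets the rule rewrite a non-attribute group-by expression in the final projection into a direct reference to the scan output column that already carries its value. A condensed, self-contained sketch of the pattern; the String keys stand in for canonicalized Catalyst expressions:

  import scala.collection.mutable

  val groupByExprToOutputOrdinal = mutable.HashMap.empty[String, Int]
  val pushedGroupBys = Seq("CASE WHEN salary > 8000 THEN 1 ELSE 0 END", "dept")
  pushedGroupBys.zipWithIndex.foreach { case (expr, ordinal) =>
    // first occurrence wins, mirroring the contains-check in the rule
    if (!groupByExprToOutputOrdinal.contains(expr)) {
      groupByExprToOutputOrdinal(expr) = ordinal
    }
  }
  groupByExprToOutputOrdinal("dept") // => 1: rewrite to scan output column 1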
JDBCScanBuilder.scala
@@ -20,7 +20,7 @@ import scala.util.control.NonFatal

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.expressions.SortOrder
+import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
@@ -70,12 +70,15 @@ case class JDBCScanBuilder(

   private var pushedAggregateList: Array[String] = Array()

-  private var pushedGroupByCols: Option[Array[String]] = None
+  private var pushedGroupBys: Option[Array[String]] = None

   override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
-    lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames()
+    lazy val fieldNames = aggregation.groupByExpressions()(0) match {
+      case field: FieldReference => field.fieldNames
+      case _ => Array.empty[String]
+    }
     jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) ||
-      (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 &&
+      (aggregation.groupByExpressions().length == 1 && fieldNames.length == 1 &&
        jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_)))
   }

@@ -86,28 +89,26 @@
     val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
     if (compiledAggs.length != aggregation.aggregateExpressions.length) return false

-    val groupByCols = aggregation.groupByColumns.map { col =>
-      if (col.fieldNames.length != 1) return false
-      dialect.quoteIdentifier(col.fieldNames.head)
-    }
+    val compiledGroupBys = aggregation.groupByExpressions.flatMap(dialect.compileExpression)
+    if (compiledGroupBys.length != aggregation.groupByExpressions.length) return false

     // The column names here are already quoted and can be used to build sql string directly.
     // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
     // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
     // GROUP BY "DEPT", "NAME"
-    val selectList = groupByCols ++ compiledAggs
-    val groupByClause = if (groupByCols.isEmpty) {
+    val selectList = compiledGroupBys ++ compiledAggs
+    val groupByClause = if (compiledGroupBys.isEmpty) {
       ""
     } else {
-      "GROUP BY " + groupByCols.mkString(",")
+      "GROUP BY " + compiledGroupBys.mkString(",")
     }

     val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " +
       s"WHERE 1=0 $groupByClause"
     try {
       finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect)
       pushedAggregateList = selectList
-      pushedGroupByCols = Some(groupByCols)
+      pushedGroupBys = Some(compiledGroupBys)
       true
     } catch {
       case NonFatal(e) =>
@@ -173,6 +174,6 @@
     // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
     // be used in sql string.
     JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate,
-      pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders)
+      pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, sortOrders)
   }
 }
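
The builder validates the compiled aggregate by sending a query that cannot return rows, only a result schema. A sketch of the assembled string with hypothetical compiled fragments (a real dialect produces them via compileAggregate and, after this change, compileExpression):

  val compiledGroupBys = Array("\"DEPT\"")
  val compiledAggs = Array("MAX(\"SALARY\")")
  val selectList = compiledGroupBys ++ compiledAggs
  val groupByClause = "GROUP BY " + compiledGroupBys.mkString(",")

  // WHERE 1=0 yields zero rows, so executing this only returns the schema,
  // which becomes the scan's finalSchema.
  val aggQuery =
    s"SELECT ${selectList.mkString(",")} FROM \"test\".\"employee\" WHERE 1=0 $groupByClause"
  // => SELECT "DEPT",MAX("SALARY") FROM "test"."employee" WHERE 1=0 GROUP BY "DEPT"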
OrcScan.scala
@@ -87,7 +87,7 @@ case class OrcScan(

   lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
     (seqToString(pushedAggregate.get.aggregateExpressions),
-      seqToString(pushedAggregate.get.groupByColumns))
+      seqToString(pushedAggregate.get.groupByExpressions))
   } else {
     ("[]", "[]")
   }
ParquetScan.scala
@@ -116,7 +116,7 @@ case class ParquetScan(

   lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
     (seqToString(pushedAggregate.get.aggregateExpressions),
-      seqToString(pushedAggregate.get.groupByColumns))
+      seqToString(pushedAggregate.get.groupByExpressions))
   } else {
     ("[]", "[]")
   }