Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.Aggregation;

/**
 * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
 * push down aggregates to the data source, so the aggregation can be evaluated at the source
 * instead of in Spark.
 *
 * @since 3.1.0
 */
@Evolving
public interface SupportsPushDownAggregates extends ScanBuilder {

/**
 * Pushes down an {@link Aggregation} to the data source. The aggregation can be pushed down
 * only if all of its aggregate functions can be pushed down.
 *
 * @param aggregation the aggregate functions and grouping columns to push down
 */
void pushAggregation(Aggregation aggregation);

/**
 * Returns the aggregates that are pushed to the data source via
 * {@link #pushAggregation(Aggregation)}.
 */
Aggregation pushedAggregation();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources

/**
 * Aggregate information to push down to a data source: the aggregate
 * functions to evaluate together with the GROUP BY column names.
 */
case class Aggregation(
    aggregateExpressions: Seq[AggregateFunc],
    groupByExpressions: Seq[String])

/** Base type of an aggregate function that may be pushed to a data source. */
abstract class AggregateFunc

// TODO: add Count

case class Min(column: String) extends AggregateFunc
case class Max(column: String) extends AggregateFunc
case class Avg(column: String, isDistinct: Boolean) extends AggregateFunc
case class Sum(column: String, isDistinct: Boolean) extends AggregateFunc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, Filter}
import org.apache.spark.sql.sources.{Aggregation, BaseRelation, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -102,6 +102,7 @@ case class RowDataSourceScanExec(
requiredSchema: StructType,
filters: Set[Filter],
handledFilters: Set[Filter],
aggregation: Aggregation,
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
Expand Down Expand Up @@ -132,9 +133,17 @@ case class RowDataSourceScanExec(
val markedFilters = for (filter <- filters) yield {
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
}
val markedAggregates = for (aggregate <- aggregation.aggregateExpressions) yield {
s"*$aggregate"
}
val markedGroupby = for (groupby <- aggregation.groupByExpressions) yield {
s"*$groupby"
}
Map(
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
"PushedFilters" -> markedFilters.mkString("[", ", ", "]"),
"PushedAggregates" -> markedAggregates.mkString("[", ", ", "]"),
"PushedGroupby" -> markedGroupby.mkString("[", ", ", "]"))
}

// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{CacheTable, InsertIntoDir, InsertIntoStatement, LogicalPlan, Project, UncacheTable}
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -358,6 +359,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]),
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
Expand Down Expand Up @@ -431,6 +433,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
Expand All @@ -453,6 +456,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
Aggregation(Seq.empty[AggregateFunc], Seq.empty[String]),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
Expand Down Expand Up @@ -700,6 +704,101 @@ object DataSourceStrategy
(nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
}

/**
 * Translates a Catalyst [[AggregateExpression]] into a data-source [[AggregateFunc]]
 * for pushdown, if supported.
 *
 * Min/Max/Average/Sum over a plain column reference (optionally wrapped in a Cast or
 * CheckOverflow) or a simple binary arithmetic expression of column references are
 * supported. Returns None for anything else, so the aggregate is not pushed down.
 */
protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {

  // Renders one of the four supported binary arithmetic operators as
  // "left op right"; "" for any other expression.
  def arithmeticAsString(e: Expression): String = e match {
    case Add(left, right, _) => arithmeticExpressionAsString(left, right, "+")
    case Subtract(left, right, _) => arithmeticExpressionAsString(left, right, "-")
    case Multiply(left, right, _) => arithmeticExpressionAsString(left, right, "*")
    case Divide(left, right, _) => arithmeticExpressionAsString(left, right, "/")
    case _ => ""
  }

  // Renders the aggregate's child as a column string; "" means unsupported.
  // Cast and CheckOverflow are looked through exactly one level, matching the
  // shapes the analyzer inserts around pushable aggregate children.
  def columnAsString(e: Expression): String = e match {
    case AttributeReference(name, _, _, _) => name
    case Cast(AttributeReference(name, _, _, _), _, _) => name
    case Cast(child, _, _) => arithmeticAsString(child)
    case CheckOverflow(child, _, _) => arithmeticAsString(child)
    case other => arithmeticAsString(other)
  }

  // Builds the pushed-down function only when the child maps to a column string.
  def toFunc(child: Expression)(build: String => AggregateFunc): Option[AggregateFunc] = {
    val columnName = columnAsString(child)
    if (columnName.nonEmpty) Some(build(columnName)) else None
  }

  aggregates.aggregateFunction match {
    case aggregate.Min(child) => toFunc(child)(Min(_))
    case aggregate.Max(child) => toFunc(child)(Max(_))
    case aggregate.Average(child) => toFunc(child)(Avg(_, aggregates.isDistinct))
    case aggregate.Sum(child) => toFunc(child)(Sum(_, aggregates.isDistinct))
    case _ => None
  }
}

/**
 * Renders a binary arithmetic expression over two column operands as
 * "left sign right" for pushdown.
 *
 * Each operand may be a bare column reference, optionally wrapped in Cast
 * and/or PromotePrecision. Returns "" when either operand is not such a
 * column, so callers' `isEmpty` checks correctly reject the pushdown.
 *
 * Note: the previous implementation used `if`/`else if` chains without a
 * final `else`, so unsupported operands were typed `Any` and rendered as
 * "()" in the output, and its `asInstanceOf` chains could throw
 * ClassCastException for PromotePrecision(Cast(nonColumn)).
 */
private def arithmeticExpressionAsString(
    left: Expression,
    right: Expression,
    sign: String): String = {

  // Extracts the column name behind an operand, looking through Cast and
  // PromotePrecision wrappers; None when the operand is not a column.
  def operandName(e: Expression): Option[String] = e match {
    case AttributeReference(name, _, _, _) => Some(name)
    case Cast(AttributeReference(name, _, _, _), _, _) => Some(name)
    case PromotePrecision(AttributeReference(name, _, _, _)) => Some(name)
    case PromotePrecision(Cast(AttributeReference(name, _, _, _), _, _)) => Some(name)
    case _ => None
  }

  (for {
    leftName <- operandName(left)
    rightName <- operandName(right)
  } yield s"$leftName $sign $rightName").getOrElse("")
}

/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ class JDBCOptions(
// An option to allow/disallow pushing down predicate into JDBC data source
val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean

// An option to allow/disallow pushing down aggregate into JDBC data source
val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "true").toBoolean

// The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
// by --files option of spark-submit or manually
val keytab = {
Expand Down Expand Up @@ -260,6 +263,7 @@ object JDBCOptions {
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
Expand Down
Loading