diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index e0d95cfaafbb0..32615e201643b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -32,6 +32,13 @@
  * The currently supported SQL aggregate functions:
  * <ol>
  *  <li><pre>AVG(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>STDDEV_SAMP(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>COVAR_POP(input1, input2)</pre> Since 3.3.0</li>
+ *  <li><pre>COVAR_SAMP(input1, input2)</pre> Since 3.3.0</li>
+ *  <li><pre>CORR(input1, input2)</pre> Since 3.3.0</li>
  * </ol>
  *
  * @since 3.3.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index fa5429678c1db..564937feb608d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -721,6 +721,26 @@ object DataSourceStrategy
         Some(new Sum(FieldReference(name), agg.isDistinct))
       case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
         Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name))))
+      case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new GeneralAggregateFunc("VAR_POP", agg.isDistinct, Array(FieldReference(name))))
+      case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new GeneralAggregateFunc("VAR_SAMP", agg.isDistinct, Array(FieldReference(name))))
+      case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new GeneralAggregateFunc("STDDEV_POP", agg.isDistinct, Array(FieldReference(name))))
+      case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new GeneralAggregateFunc("STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name))))
+      case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
+          PushableColumnWithoutNestedColumn(right), _) =>
+        Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct,
+          Array(FieldReference(left), FieldReference(right))))
+      case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
+          PushableColumnWithoutNestedColumn(right), _) =>
+        Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct,
+          Array(FieldReference(left), FieldReference(right))))
+      case aggregate.Corr(PushableColumnWithoutNestedColumn(left),
+          PushableColumnWithoutNestedColumn(right), _) =>
+        Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
+          Array(FieldReference(left), FieldReference(right))))
       case _ => None
     }
   } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 9c727957ffab8..087c3573fbdbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -22,11 +22,36 @@ import java.util.Locale
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 
 private object H2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
 
+  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+    super.compileAggregate(aggFunction).orElse(
+      aggFunction match {
+        case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
+          assert(f.inputs().length == 1)
+          val distinct = if (f.isDistinct) "DISTINCT " else ""
+          Some(s"VAR_POP($distinct${f.inputs().head})")
+        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
+          assert(f.inputs().length == 1)
+          val distinct = if (f.isDistinct) "DISTINCT " else ""
+          Some(s"VAR_SAMP($distinct${f.inputs().head})")
+        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
+          assert(f.inputs().length == 1)
+          val distinct = if (f.isDistinct) "DISTINCT " else ""
+          Some(s"STDDEV_POP($distinct${f.inputs().head})")
+        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
+          assert(f.inputs().length == 1)
+          val distinct = if (f.isDistinct) "DISTINCT " else ""
+          Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
+        case _ => None
+      }
+    )
+  }
+
   override def classifyException(message: String, e: Throwable): AnalysisException = {
     if (e.isInstanceOf[SQLException]) {
       // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 9d37a85a2c916..3ae9d0322c6ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -698,6 +698,66 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1)))
   }
 
+  test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") {
+    val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" +
+      " group by DePt")
+    checkFiltersRemoved(df)
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " +
+          "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+          "PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+  }
+
+  test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") {
+    val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" +
+      " where dept > 0 group by DePt")
+    checkFiltersRemoved(df)
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " +
+          "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+          "PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null)))
+  }
+
+  test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") {
+    val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" +
+      " FROM h2.test.employee where dept > 0 group by DePt")
+    checkFiltersRemoved(df)
+    checkAggregateRemoved(df, false)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+  }
+
+  test("scan with aggregate push-down: CORR with filter and group by") {
+    val df = sql("select CORR(bonus, bonus) FROM h2.test.employee where dept > 0" +
+      " group by DePt")
+    checkFiltersRemoved(df)
+    checkAggregateRemoved(df, false)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
+  }
+
   test("scan with aggregate push-down: aggregate over alias NOT push down") {
     val cols = Seq("a", "b", "c", "d")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
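
Note on the DataSourceStrategy hunk: the new cases only match when every aggregate argument is a plain top-level column (the `PushableColumnWithoutNestedColumn` extractor), so aggregates over nested fields or computed expressions are not translated and keep running in Spark. A minimal sketch of what the translation builds, reusing the `GeneralAggregateFunc`/`FieldReference` calls from the patch; the column names here are illustrative, not from the patch:

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

// VAR_POP(bonus) becomes a single-argument GeneralAggregateFunc carrying the
// upper-cased SQL name and one column reference.
val varPop = new GeneralAggregateFunc("VAR_POP", false, Array(FieldReference("bonus")))

// Two-argument functions such as CORR(bonus, salary) carry both columns.
val corr = new GeneralAggregateFunc(
  "CORR", false, Array(FieldReference("bonus"), FieldReference("salary")))
```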
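
Note on the H2Dialect hunk: the four new `compileAggregate` cases are identical except for the function name, so the `DISTINCT` handling is spelled out four times. A possible cleanup, not part of this patch, would route the single-argument functions through one helper; the sketch below assumes the same `name()`/`isDistinct`/`inputs()` API shown in the diff, and `compileSingleArgAggregate` is a hypothetical name:

```scala
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

// Hypothetical helper: reuses f.name() as the SQL function name instead of
// duplicating the quoting logic in one case per supported function.
private def compileSingleArgAggregate(f: GeneralAggregateFunc): Option[String] = {
  assert(f.inputs().length == 1)
  val distinct = if (f.isDistinct) "DISTINCT " else ""
  Some(s"${f.name()}($distinct${f.inputs().head})")
}

// The four cases would then collapse to a single guard:
//   case f: GeneralAggregateFunc
//       if Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")(f.name()) =>
//     compileSingleArgAggregate(f)
```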
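
On the test side, the VAR_*/STDDEV_* tests assert the full `PushedAggregates`/`PushedGroupByColumns` info, while the COVAR_*/CORR tests call `checkAggregateRemoved(df, false)` and only check `PushedFilters`: `translateAggregate` can now build `GeneralAggregateFunc` for covariance and correlation, but `H2Dialect.compileAggregate` above does not compile them, so those aggregates stay in Spark's plan and only the filter reaches H2. A hypothetical spark-shell session (assuming an `h2` catalog configured like the one in JDBCV2Suite) shows the fully pushed-down case:

```scala
// Assumes spark.sql.catalog.h2 is registered as a JDBC catalog pointing at the
// H2 test database used by JDBCV2Suite; table and column names mirror its fixtures.
val df = spark.sql(
  """SELECT VAR_POP(bonus), VAR_SAMP(bonus)
    |FROM h2.test.employee
    |WHERE dept > 0
    |GROUP BY dept""".stripMargin)

df.explain()
// The V2 scan node is expected to report, as the new tests assert:
//   PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)],
//   PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)],
//   PushedGroupByColumns: [DEPT]
```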