diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index e0d95cfaafbb0..32615e201643b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -32,6 +32,13 @@
* The currently supported SQL aggregate functions:
*
* AVG(input1)
Since 3.3.0
+ * VAR_POP(input1)
Since 3.3.0
+ * VAR_SAMP(input1)
Since 3.3.0
+ * STDDEV_POP(input1)
Since 3.3.0
+ * STDDEV_SAMP(input1)
Since 3.3.0
+ * COVAR_POP(input1, input2)
Since 3.3.0
+ * COVAR_SAMP(input1, input2)
Since 3.3.0
+ * CORR(input1, input2)
Since 3.3.0
*
*
* @since 3.3.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index fa5429678c1db..564937feb608d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -721,6 +721,26 @@ object DataSourceStrategy
Some(new Sum(FieldReference(name), agg.isDistinct))
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name))))
+ case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
+ Some(new GeneralAggregateFunc("VAR_POP", agg.isDistinct, Array(FieldReference(name))))
+ case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
+ Some(new GeneralAggregateFunc("VAR_SAMP", agg.isDistinct, Array(FieldReference(name))))
+ case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
+ Some(new GeneralAggregateFunc("STDDEV_POP", agg.isDistinct, Array(FieldReference(name))))
+ case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
+ Some(new GeneralAggregateFunc("STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name))))
+ case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
+ PushableColumnWithoutNestedColumn(right), _) =>
+ Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct,
+ Array(FieldReference(left), FieldReference(right))))
+ case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
+ PushableColumnWithoutNestedColumn(right), _) =>
+ Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct,
+ Array(FieldReference(left), FieldReference(right))))
+ case aggregate.Corr(PushableColumnWithoutNestedColumn(left),
+ PushableColumnWithoutNestedColumn(right), _) =>
+ Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
+ Array(FieldReference(left), FieldReference(right))))
case _ => None
}
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 9c727957ffab8..087c3573fbdbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -22,11 +22,36 @@ import java.util.Locale
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
private object H2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
+ assert(f.inputs().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VAR_POP($distinct${f.inputs().head})")
+ case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
+ assert(f.inputs().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VAR_SAMP($distinct${f.inputs().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
+ assert(f.inputs().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV_POP($distinct${f.inputs().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
+ assert(f.inputs().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
+ case _ => None
+ }
+ )
+ }
+
override def classifyException(message: String, e: Throwable): AnalysisException = {
if (e.isInstanceOf[SQLException]) {
// Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 9d37a85a2c916..3ae9d0322c6ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -698,6 +698,66 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1)))
}
+ test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") {
+ val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" +
+ " group by DePt")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " +
+ "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+ "PushedGroupByColumns: [DEPT]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+ }
+
+ test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") {
+ val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" +
+ " where dept > 0 group by DePt")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " +
+ "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+ "PushedGroupByColumns: [DEPT]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null)))
+ }
+
+ test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") {
+ val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" +
+ " FROM h2.test.employee where dept > 0 group by DePt")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df, false)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+ }
+
+ test("scan with aggregate push-down: CORR with filter and group by") {
+ val df = sql("select CORR(bonus, bonus) FROM h2.test.employee where dept > 0" +
+ " group by DePt")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df, false)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
+ }
+
test("scan with aggregate push-down: aggregate over alias NOT push down") {
val cols = Seq("a", "b", "c", "d")
val df1 = sql("select * from h2.test.employee").toDF(cols: _*)