diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index 1a1a592d00bca..79ef4298eaa99 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -21,6 +21,8 @@ import java.sql.Connection
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.connector.expressions.aggregate.{Corr, CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
 import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite}
 import org.apache.spark.sql.types._
@@ -50,9 +52,25 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes
     .set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort))
     .set("spark.sql.catalog.postgresql.pushDownTableSample", "true")
+    .set("spark.sql.catalog.postgresql.pushDownAggregate", "true")
     .set("spark.sql.catalog.postgresql.pushDownLimit", "true")
 
-  override def dataPreparation(conn: Connection): Unit = {}
+  override def dataPreparation(conn: Connection): Unit = {
+    conn.prepareStatement("CREATE SCHEMA \"test\"").executeUpdate()
+    conn.prepareStatement(
+      "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," +
+        " bonus double precision)").executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)")
+      .executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)")
+      .executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)")
+      .executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)")
+      .executeUpdate()
+    conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)")
+      .executeUpdate()
+  }
 
   override def testUpdateColumnType(tbl: String): Unit = {
     sql(s"CREATE TABLE $tbl (ID INTEGER)")
@@ -80,4 +98,96 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes
   }
 
   override def supportsTableSample: Boolean = true
+
+  test("scan with aggregate push-down: VAR_POP VAR_SAMP") {
+    val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM postgresql.test.employee" +
+      " where dept > 0 group by DePt order by dept")
+    df.queryExecution.optimizedPlan.collect {
+      case DataSourceV2ScanRelation(_, scan, _) =>
+        assert(scan.isInstanceOf[V1ScanWrapper])
+        val wrapper = scan.asInstanceOf[V1ScanWrapper]
+        assert(wrapper.pushedDownOperators.aggregation.isDefined)
+        val aggregationExpressions =
+          wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
+        assert(aggregationExpressions(0).isInstanceOf[VarPop])
+        assert(aggregationExpressions(1).isInstanceOf[VarSamp])
+    }
+    val row = df.collect()
+    assert(row.length === 3)
+    assert(row(0).length === 2)
+    assert(row(0).getDouble(0) === 10000d)
+    assert(row(0).getDouble(1) === 20000d)
+    assert(row(1).getDouble(0) === 2500d)
+    assert(row(1).getDouble(1) === 5000d)
+    assert(row(2).getDouble(0) === 0d)
+    assert(row(2).isNullAt(1))
+  }
+
+  test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP") {
+    val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM postgresql.test.employee" +
+      " where dept > 0 group by DePt order by dept")
+    df.queryExecution.optimizedPlan.collect {
+      case DataSourceV2ScanRelation(_, scan, _) =>
+        assert(scan.isInstanceOf[V1ScanWrapper])
+        val wrapper = scan.asInstanceOf[V1ScanWrapper]
+        assert(wrapper.pushedDownOperators.aggregation.isDefined)
+        val aggregationExpressions =
+          wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
+        assert(aggregationExpressions(0).isInstanceOf[StddevPop])
+        assert(aggregationExpressions(1).isInstanceOf[StddevSamp])
+    }
+    val row = df.collect()
+    assert(row.length === 3)
+    assert(row(0).length === 2)
+    assert(row(0).getDouble(0) === 100d)
+    assert(row(0).getDouble(1) === 141.4213562373095d)
+    assert(row(1).getDouble(0) === 50d)
+    assert(row(1).getDouble(1) === 70.71067811865476d)
+    assert(row(2).getDouble(0) === 0d)
+    assert(row(2).isNullAt(1))
+  }
+
+  test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") {
+    val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus) FROM" +
+      " postgresql.test.employee where dept > 0 group by DePt order by dept")
+    df.queryExecution.optimizedPlan.collect {
+      case DataSourceV2ScanRelation(_, scan, _) =>
+        assert(scan.isInstanceOf[V1ScanWrapper])
+        val wrapper = scan.asInstanceOf[V1ScanWrapper]
+        assert(wrapper.pushedDownOperators.aggregation.isDefined)
+        val aggregationExpressions =
+          wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
+        assert(aggregationExpressions(0).isInstanceOf[CovarPop])
+        assert(aggregationExpressions(1).isInstanceOf[CovarSamp])
+    }
+    val row = df.collect()
+    assert(row.length === 3)
+    assert(row(0).length === 2)
+    assert(row(0).getDouble(0) === 10000d)
+    assert(row(0).getDouble(1) === 20000d)
+    assert(row(1).getDouble(0) === 2500d)
+    assert(row(1).getDouble(1) === 5000d)
+    assert(row(2).getDouble(0) === 0d)
+    assert(row(2).isNullAt(1))
+  }
+
+  test("scan with aggregate push-down: CORR with filter and group by") {
+    val df = sql("select CORR(bonus, bonus) FROM postgresql.test.employee where dept > 0" +
+      " group by DePt order by dept")
+    df.queryExecution.optimizedPlan.collect {
+      case DataSourceV2ScanRelation(_, scan, _) =>
+        assert(scan.isInstanceOf[V1ScanWrapper])
+        val wrapper = scan.asInstanceOf[V1ScanWrapper]
+        assert(wrapper.pushedDownOperators.aggregation.isDefined)
+        val aggregationExpressions =
+          wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
+        assert(aggregationExpressions(0).isInstanceOf[Corr])
+    }
+    val row = df.collect()
+    assert(row.length === 3)
+    assert(row(0).length === 1)
+    assert(row(0).getDouble(0) === 1d)
+    assert(row(1).getDouble(0) === 1d)
+    assert(row(2).isNullAt(0))
+  }
 }
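The expected rows in these tests follow directly from the seeded `employee` data: dept 1 has bonuses (1000, 1200), dept 2 has (1200, 1300), and dept 6 has the single bonus 1200, for which the sample statistics are undefined (NULL) while the population variance is 0. A minimal standalone check of the asserted values (plain Scala, not part of the patch):

```scala
// Recompute the VAR_POP/VAR_SAMP/STDDEV values asserted in the tests above.
object VerifyExpectedStats {
  def varPop(xs: Seq[Double]): Double = {
    val mean = xs.sum / xs.size
    xs.map(x => (x - mean) * (x - mean)).sum / xs.size
  }

  def varSamp(xs: Seq[Double]): Double = {
    val mean = xs.sum / xs.size
    xs.map(x => (x - mean) * (x - mean)).sum / (xs.size - 1)
  }

  def main(args: Array[String]): Unit = {
    println(varPop(Seq(1000d, 1200d)))              // 10000.0 -> VAR_POP, dept 1
    println(varSamp(Seq(1000d, 1200d)))             // 20000.0 -> VAR_SAMP, dept 1
    println(varPop(Seq(1200d, 1300d)))              // 2500.0  -> VAR_POP, dept 2
    println(varSamp(Seq(1200d, 1300d)))             // 5000.0  -> VAR_SAMP, dept 2
    println(math.sqrt(varSamp(Seq(1000d, 1200d))))  // 141.4213562373095 -> STDDEV_SAMP, dept 1
  }
}
```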
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Average.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Average.java
new file mode 100644
index 0000000000000..05c0b4c0635d8
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Average.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the mean calculated from values of a group.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class Average implements AggregateFunc {
+  private final NamedReference column;
+
+  public Average(NamedReference column) { this.column = column; }
+
+  public NamedReference column() { return column; }
+
+  @Override
+  public String toString() { return "AVG(" + column.describe() + ")"; }
+
+  @Override
+  public String describe() { return this.toString(); }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java
new file mode 100644
index 0000000000000..b4ba442f9a9ae
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the Pearson correlation coefficient of
+ * a set of number pairs.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class Corr implements AggregateFunc {
+  private final NamedReference left;
+  private final NamedReference right;
+
+  public Corr(NamedReference left, NamedReference right) {
+    this.left = left;
+    this.right = right;
+  }
+
+  public NamedReference left() { return left; }
+
+  public NamedReference right() { return right; }
+
+  @Override
+  public String toString() {
+    return "CORR(" + left.describe() + "," + right.describe() + ")";
+  }
+
+  @Override
+  public String describe() { return this.toString(); }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java
new file mode 100644
index 0000000000000..e8d91563f7e12
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the population covariance of a set of number pairs.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class CovarPop implements AggregateFunc {
+  private final NamedReference left;
+  private final NamedReference right;
+
+  public CovarPop(NamedReference left, NamedReference right) {
+    this.left = left;
+    this.right = right;
+  }
+
+  public NamedReference left() { return left; }
+
+  public NamedReference right() { return right; }
+
+  @Override
+  public String toString() {
+    return "COVAR_POP(" + left.describe() + "," + right.describe() + ")";
+  }
+
+  @Override
+  public String describe() { return this.toString(); }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java
new file mode 100644
index 0000000000000..dcaaf93da8b49
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the sample covariance of a set of number pairs.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class CovarSamp implements AggregateFunc {
+  private final NamedReference left;
+  private final NamedReference right;
+
+  public CovarSamp(NamedReference left, NamedReference right) {
+    this.left = left;
+    this.right = right;
+  }
+
+  public NamedReference left() { return left; }
+
+  public NamedReference right() { return right; }
+
+  @Override
+  public String toString() {
+    return "COVAR_SAMP(" + left.describe() + "," + right.describe() + ")";
+  }
+
+  @Override
+  public String describe() { return this.toString(); }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java
new file mode 100644
index 0000000000000..4bb6d036eaad2
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the population standard deviation calculated from
+ * values of a group.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class StddevPop implements AggregateFunc {
+  private final NamedReference column;
+
+  public StddevPop(NamedReference column) { this.column = column; }
+
+  public NamedReference column() { return column; }
+
+  @Override
+  public String toString() { return "STDDEV_POP(" + column.describe() + ")"; }
+
+  @Override
+  public String describe() { return this.toString(); }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java
new file mode 100644
index 0000000000000..d13783c02f8d3
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the sample standard deviation calculated from
+ * values of a group.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class StddevSamp implements AggregateFunc {
+  private final NamedReference column;
+
+  public StddevSamp(NamedReference column) { this.column = column; }
+
+  public NamedReference column() { return column; }
+
+  @Override
+  public String toString() { return "STDDEV_SAMP(" + column.describe() + ")"; }
+
+  @Override
+  public String describe() { return this.toString(); }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarPop.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarPop.java
new file mode 100644
index 0000000000000..568c89e2141ae
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarPop.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the population variance calculated from values of a group.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class VarPop implements AggregateFunc {
+  private final NamedReference column;
+
+  public VarPop(NamedReference column) { this.column = column; }
+
+  public NamedReference column() { return column; }
+
+  @Override
+  public String toString() { return "VAR_POP(" + column.describe() + ")"; }
+
+  @Override
+  public String describe() { return this.toString(); }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarSamp.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarSamp.java
new file mode 100644
index 0000000000000..07ce1538366c6
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarSamp.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * An aggregate function that returns the sample variance calculated from values of a group.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class VarSamp implements AggregateFunc {
+  private final NamedReference column;
+
+  public VarSamp(NamedReference column) { this.column = column; }
+
+  public NamedReference column() { return column; }
+
+  @Override
+  public String toString() { return "VAR_SAMP(" + column.describe() + ")"; }
+
+  @Override
+  public String describe() { return this.toString(); }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 6febbd590f246..57da9b25b93d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.FieldReference
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Average, Corr, Count, CountStar, CovarPop, CovarSamp, Max, Min, StddevPop, StddevSamp, Sum, VarPop, VarSamp}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.command._
@@ -717,8 +717,27 @@ object DataSourceStrategy
             Some(new Count(FieldReference(name), agg.isDistinct))
           case _ => None
         }
-      case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
+      case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
         Some(new Sum(FieldReference(name), agg.isDistinct))
+      case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new Average(FieldReference(name)))
+      case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new VarPop(FieldReference(name)))
+      case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new VarSamp(FieldReference(name)))
+      case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new StddevPop(FieldReference(name)))
+      case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new StddevSamp(FieldReference(name)))
+      case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
+          PushableColumnWithoutNestedColumn(right), _) =>
+        Some(new CovarPop(FieldReference(left), FieldReference(right)))
+      case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
+          PushableColumnWithoutNestedColumn(right), _) =>
+        Some(new CovarSamp(FieldReference(left), FieldReference(right)))
+      case aggregate.Corr(PushableColumnWithoutNestedColumn(x),
+          PushableColumnWithoutNestedColumn(y), _) =>
+        Some(new Corr(FieldReference(x), FieldReference(y)))
       case _ => None
     }
   } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 36923d1c0bd64..4ce47a15b24f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -176,7 +176,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
             case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
             case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
             case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
+            case avg: aggregate.Average =>
+              aggregate.First(Cast(aggOutput(ordinal), avg.dataType), true)
             case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
+            case _: aggregate.VariancePop | _: aggregate.VarianceSamp |
+                 _: aggregate.StddevPop | _: aggregate.StddevSamp |
+                 _: aggregate.CovPopulation | _: aggregate.CovSample |
+                 _: aggregate.Corr =>
+              aggregate.First(aggOutput(ordinal), true)
             case other => other
           }
           agg.copy(aggregateFunction = aggFunction)
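Once one of these aggregates is pushed down, the JDBC source hands back exactly one fully-computed value per group, so the Spark-side "final" aggregation degenerates to picking that value, which is what the `First(..., ignoreNulls = true)` rewrite above does; the extra `Cast` for `Average` reconciles the column type returned by the source with Catalyst's `Average` result type. A toy illustration of the idea (plain Scala, not part of the patch):

```scala
// Each group arrives from the database as a single pre-aggregated row,
// e.g. (dept, VAR_POP(bonus)) pairs already computed remotely.
val fromSource = Seq(1 -> 10000.0, 2 -> 2500.0, 6 -> 0.0)

// The "final" aggregation per group is then just First over one value.
val result = fromSource.groupBy(_._1).map { case (dept, rows) =>
  dept -> rows.head._2
}
println(result)  // e.g. Map(1 -> 10000.0, 2 -> 2500.0, 6 -> 0.0), ordering may vary
```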
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 9c727957ffab8..3a99402f2f9e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -22,11 +22,32 @@ import java.util.Locale
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, StddevPop, StddevSamp, VarPop, VarSamp}
 
 private object H2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
 
+  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+    super.compileAggregate(aggFunction).orElse(
+      aggFunction match {
+        case varPop: VarPop =>
+          if (varPop.column.fieldNames.length != 1) return None
+          Some(s"VAR_POP(${quoteIdentifier(varPop.column.fieldNames.head)})")
+        case varSamp: VarSamp =>
+          if (varSamp.column.fieldNames.length != 1) return None
+          Some(s"VAR_SAMP(${quoteIdentifier(varSamp.column.fieldNames.head)})")
+        case stddevPop: StddevPop =>
+          if (stddevPop.column.fieldNames.length != 1) return None
+          Some(s"STDDEV_POP(${quoteIdentifier(stddevPop.column.fieldNames.head)})")
+        case stddevSamp: StddevSamp =>
+          if (stddevSamp.column.fieldNames.length != 1) return None
+          Some(s"STDDEV_SAMP(${quoteIdentifier(stddevSamp.column.fieldNames.head)})")
+        case _ => None
+      }
+    )
+  }
+
   override def classifyException(message: String, e: Throwable): AnalysisException = {
     if (e.isInstanceOf[SQLException]) {
       // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 9a647e545d836..100f7285ebae7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.NamedReference
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Average, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -208,6 +208,9 @@ abstract class JdbcDialect extends Serializable with Logging{
       case max: Max =>
         if (max.column.fieldNames.length != 1) return None
         Some(s"MAX(${quoteIdentifier(max.column.fieldNames.head)})")
+      case avg: Average =>
+        if (avg.column.fieldNames.length != 1) return None
+        Some(s"AVG(${quoteIdentifier(avg.column.fieldNames.head)})")
       case count: Count =>
         if (count.column.fieldNames.length != 1) return None
         val distinct = if (count.isDistinct) "DISTINCT " else ""
@@ -224,6 +227,20 @@ abstract class JdbcDialect extends Serializable with Logging{
     }
   }
 
+  protected def compileLeftRight(
+      left: NamedReference, right: NamedReference, funcName: String): Option[String] = {
+    if (left.fieldNames.length != 1 || right.fieldNames.length != 1) {
+      return None
+    }
+    val compiledValue =
+      s"""
+         |$funcName(
+         |${quoteIdentifier(left.fieldNames.head)},
+         |${quoteIdentifier(right.fieldNames.head)})
+         |""".stripMargin.replaceAll("\n", "")
+    Some(compiledValue)
+  }
+
   /**
   * Return Some[true] iff `TRUNCATE TABLE` causes cascading default.
   * Some[true] : TRUNCATE TABLE causes cascading.
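`compileLeftRight` gives dialects one place to render the two-argument aggregates. It bails out with `None` whenever either side is a nested (multi-part) column reference, because `quoteIdentifier` only handles a single name part; returning `None` keeps the aggregate in Spark instead of pushing it down. A standalone sketch of the same logic (plain Scala, not the real protected method; `quoteIdentifier` is stubbed with the default double-quoting):

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}

def quoteIdentifier(colName: String): String = s""""$colName""""

def compileLeftRight(
    left: NamedReference, right: NamedReference, funcName: String): Option[String] = {
  if (left.fieldNames.length != 1 || right.fieldNames.length != 1) {
    None
  } else {
    Some(s"$funcName(${quoteIdentifier(left.fieldNames.head)}," +
      quoteIdentifier(right.fieldNames.head) + ")")
  }
}

// Some(CORR("bonus","salary"))
println(compileLeftRight(FieldReference("bonus"), FieldReference("salary"), "CORR"))
// None: a nested reference such as a.b cannot be quoted as a single identifier.
println(compileLeftRight(FieldReference(Seq("a", "b")), FieldReference("salary"), "CORR"))
```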
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 317ae19ed914b..3ede234569353 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
 import java.sql.{Connection, Types}
 import java.util.Locale
 
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Corr, CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp}
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.types._
@@ -139,6 +140,32 @@ private object PostgresDialect extends JdbcDialect {
     }
   }
 
+  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+    super.compileAggregate(aggFunction).orElse(
+      aggFunction match {
+        case varPop: VarPop =>
+          if (varPop.column.fieldNames.length != 1) return None
+          Some(s"VAR_POP(${quoteIdentifier(varPop.column.fieldNames.head)})")
+        case varSamp: VarSamp =>
+          if (varSamp.column.fieldNames.length != 1) return None
+          Some(s"VAR_SAMP(${quoteIdentifier(varSamp.column.fieldNames.head)})")
+        case stddevPop: StddevPop =>
+          if (stddevPop.column.fieldNames.length != 1) return None
+          Some(s"STDDEV_POP(${quoteIdentifier(stddevPop.column.fieldNames.head)})")
+        case stddevSamp: StddevSamp =>
+          if (stddevSamp.column.fieldNames.length != 1) return None
+          Some(s"STDDEV_SAMP(${quoteIdentifier(stddevSamp.column.fieldNames.head)})")
+        case covarPop: CovarPop =>
+          compileLeftRight(covarPop.left, covarPop.right, "COVAR_POP")
+        case covarSamp: CovarSamp =>
+          compileLeftRight(covarSamp.left, covarSamp.right, "COVAR_SAMP")
+        case corr: Corr =>
+          compileLeftRight(corr.left, corr.right, "CORR")
+        case _ => None
+      }
+    )
+  }
+
   // See https://www.postgresql.org/docs/12/sql-altertable.html
   override def getUpdateColumnTypeQuery(
       tableName: String,
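Combined with `pushDownAggregate=true`, these compile rules let the Postgres integration tests above run their filter, grouping, and aggregation entirely on the database. The statement the scan emits should look roughly like the sketch below (illustration only; the exact text, quoting, and clause order are produced by the JDBC scan builder):

```scala
// Approximate shape of the query pushed to Postgres for the
// COVAR_POP/COVAR_SAMP test; an illustration, not captured output.
val pushedQuery =
  """SELECT "dept", COVAR_POP("bonus","bonus"), COVAR_SAMP("bonus","bonus")
    |FROM "test"."employee"
    |WHERE "dept" IS NOT NULL AND "dept" > 0
    |GROUP BY "dept"""".stripMargin
```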
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index b87b4f6d86fd1..0765e9573db66 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -316,6 +316,60 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
   }
 
+  test("scan with aggregate push-down: AVG with filter and group by") {
+    val df = sql("select AVG(SaLaRY) FROM h2.test.employee where dept > 0" +
+      " group by DePt")
+    val filters = df.queryExecution.optimizedPlan.collect {
+      case f: Filter => f
+    }
+    assert(filters.isEmpty)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [AVG(SALARY)], " +
+            "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+            "PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(9500), Row(11000), Row(12000)))
+  }
+
+  test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") {
+    val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" +
+      " group by DePt")
+    val filters = df.queryExecution.optimizedPlan.collect {
+      case f: Filter => f
+    }
+    assert(filters.isEmpty)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " +
+            "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+            "PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+  }
+
+  test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") {
+    val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" +
+      " where dept > 0 group by DePt")
+    val filters = df.queryExecution.optimizedPlan.collect {
+      case f: Filter => f
+    }
+    assert(filters.isEmpty)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " +
+            "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+            "PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df,
+      Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null)))
+  }
+
   test("scan with aggregate push-down: MAX MIN with filter and group by") {
     val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
       " group by DePt")