From beb75b89d9ad4abb95700d94db558f836fd638d3 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 3 Dec 2021 12:07:21 +0800 Subject: [PATCH 1/8] [SPARK-37527] Translate more standard aggregate functions for pushdown --- .../expressions/aggregate/AnyOrSome.java | 48 ++++++++++++++++++ .../connector/expressions/aggregate/Corr.java | 49 +++++++++++++++++++ .../expressions/aggregate/CovarPop.java | 49 +++++++++++++++++++ .../expressions/aggregate/CovarSamp.java | 49 +++++++++++++++++++ .../expressions/aggregate/Every.java | 41 ++++++++++++++++ .../expressions/aggregate/StddevPop.java | 41 ++++++++++++++++ .../expressions/aggregate/StddevSamp.java | 41 ++++++++++++++++ .../expressions/aggregate/VarPop.java | 41 ++++++++++++++++ .../expressions/aggregate/VarSamp.java | 41 ++++++++++++++++ .../datasources/DataSourceStrategy.scala | 30 +++++++++++- 10 files changed, 428 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/AnyOrSome.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Every.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarPop.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarSamp.java diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/AnyOrSome.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/AnyOrSome.java new file mode 100644 index 0000000000000..6186ede8e9375 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/AnyOrSome.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns true if at least one value of `column` is true. 
+ * + * @since 3.3.0 + */ +@Evolving +public final class AnyOrSome implements AggregateFunc { + public static final String ANY = "ANY"; + public static final String SOME = "SOME"; + + private final NamedReference column; + private final String realName; + + public AnyOrSome(NamedReference column, String realName) { + this.column = column; + this.realName = realName; + } + + public NamedReference column() { return column; } + + @Override + public String toString() { return realName + "(" + column.describe() + ")"; } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java new file mode 100644 index 0000000000000..e0e684f563e2d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns Pearson coefficient of correlation between a set of number pairs. + * + * @since 3.3.0 + */ +@Evolving +public final class Corr implements AggregateFunc { + private final NamedReference x; + private final NamedReference y; + + public Corr(NamedReference left, NamedReference right) { + this.x = left; + this.y = right; + } + + public NamedReference getX() { return x; } + + public NamedReference getY() { return y; } + + @Override + public String toString() { + return "CORR(" + x.describe() + "," + y.describe() + ")"; + } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java new file mode 100644 index 0000000000000..fb7afd873ae50 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns the population covariance of a set of number pairs. + * + * @since 3.3.0 + */ +@Evolving +public final class CovarPop implements AggregateFunc { + private final NamedReference left; + private final NamedReference right; + + public CovarPop(NamedReference left, NamedReference right) { + this.left = left; + this.right = right; + } + + public NamedReference getLeft() { return left; } + + public NamedReference getRight() { return right; } + + @Override + public String toString() { + return "COVAR_POP(" + left.describe() + "," + right.describe() + ")"; + } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java new file mode 100644 index 0000000000000..d861b7a581a25 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns the sample covariance of a set of number pairs. + * + * @since 3.3.0 + */ +@Evolving +public final class CovarSamp implements AggregateFunc { + private final NamedReference left; + private final NamedReference right; + + public CovarSamp(NamedReference left, NamedReference right) { + this.left = left; + this.right = right; + } + + public NamedReference getLeft() { return left; } + + public NamedReference getRight() { return right; } + + @Override + public String toString() { + return "COVAR_SAMP(" + left.describe() + "," + right.describe() + ")"; + } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Every.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Every.java new file mode 100644 index 0000000000000..ecfd446ffc22f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Every.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns true if all values of `column` are true. + * + * @since 3.3.0 + */ +@Evolving +public final class Every implements AggregateFunc { + private final NamedReference column; + + public Every(NamedReference column) { this.column = column; } + + public NamedReference column() { return column; } + + @Override + public String toString() { return "EVERY(" + column.describe() + ")"; } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java new file mode 100644 index 0000000000000..5db97664081c8 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns the population standard deviation calculated from values of a group. + * + * @since 3.3.0 + */ +@Evolving +public final class StddevPop implements AggregateFunc { + private final NamedReference column; + + public StddevPop(NamedReference column) { this.column = column; } + + public NamedReference column() { return column; } + + @Override + public String toString() { return "STDDEV_POP(" + column.describe() + ")"; } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java new file mode 100644 index 0000000000000..622d5de555cb2 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns the sample standard deviation calculated from values of a group. + * + * @since 3.3.0 + */ +@Evolving +public final class StddevSamp implements AggregateFunc { + private final NamedReference column; + + public StddevSamp(NamedReference column) { this.column = column; } + + public NamedReference column() { return column; } + + @Override + public String toString() { return "STDDEV_SAMP(" + column.describe() + ")"; } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarPop.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarPop.java new file mode 100644 index 0000000000000..568c89e2141ae --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarPop.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns the population variance calculated from values of a group. + * + * @since 3.3.0 + */ +@Evolving +public final class VarPop implements AggregateFunc { + private final NamedReference column; + + public VarPop(NamedReference column) { this.column = column; } + + public NamedReference column() { return column; } + + @Override + public String toString() { return "VAR_POP(" + column.describe() + ")"; } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarSamp.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarSamp.java new file mode 100644 index 0000000000000..07ce1538366c6 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/VarSamp.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns the sample variance calculated from values of a group. + * + * @since 3.3.0 + */ +@Evolving +public final class VarSamp implements AggregateFunc { + private final NamedReference column; + + public VarSamp(NamedReference column) { this.column = column; } + + public NamedReference column() { return column; } + + @Override + public String toString() { return "VAR_SAMP(" + column.describe() + ")"; } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6febbd590f246..0b023e4dfc518 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.FieldReference 
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, AnyOrSome, Corr, Count, CountStar, CovarPop, CovarSamp, Every, Max, Min, StddevPop, StddevSamp, Sum, VarPop, VarSamp} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -717,8 +717,34 @@ object DataSourceStrategy Some(new Count(FieldReference(name), agg.isDistinct)) case _ => None } - case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => + case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => Some(new Sum(FieldReference(name), agg.isDistinct)) + case every @ aggregate.BoolAnd(PushableColumnWithoutNestedColumn(name)) + if every.nodeName == "every" => + Some(new Every(FieldReference(name))) + case any @ aggregate.BoolOr(PushableColumnWithoutNestedColumn(name)) + if any.nodeName == "any" => + Some(new AnyOrSome(FieldReference(name), AnyOrSome.ANY)) + case some @ aggregate.BoolOr(PushableColumnWithoutNestedColumn(name)) + if some.nodeName == "some" => + Some(new AnyOrSome(FieldReference(name), AnyOrSome.SOME)) + case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) => + Some(new VarPop(FieldReference(name))) + case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) => + Some(new VarSamp(FieldReference(name))) + case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) => + Some(new StddevPop(FieldReference(name))) + case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) => + Some(new StddevSamp(FieldReference(name))) + case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left), + PushableColumnWithoutNestedColumn(right), _) => + Some(new CovarPop(FieldReference(left), FieldReference(right))) + case 
aggregate.CovSample(PushableColumnWithoutNestedColumn(left), + PushableColumnWithoutNestedColumn(right), _) => + Some(new CovarSamp(FieldReference(left), FieldReference(right))) + case aggregate.Corr(PushableColumnWithoutNestedColumn(x), + PushableColumnWithoutNestedColumn(y), _) => + Some(new Corr(FieldReference(x), FieldReference(y))) case _ => None } } else { From 11865ecd5faf125731fdc37af0466536f889b41e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 3 Dec 2021 20:14:42 +0800 Subject: [PATCH 2/8] Update code --- .../apache/spark/sql/connector/expressions/aggregate/Corr.java | 3 ++- .../spark/sql/connector/expressions/aggregate/StddevPop.java | 3 ++- .../spark/sql/connector/expressions/aggregate/StddevSamp.java | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java index e0e684f563e2d..67bc18e8b84d5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java @@ -21,7 +21,8 @@ import org.apache.spark.sql.connector.expressions.NamedReference; /** - * An aggregate function that returns Pearson coefficient of correlation between a set of number pairs. + * An aggregate function that returns Pearson coefficient of correlation between + * a set of number pairs. 
* * @since 3.3.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java index 5db97664081c8..4bb6d036eaad2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevPop.java @@ -21,7 +21,8 @@ import org.apache.spark.sql.connector.expressions.NamedReference; /** - * An aggregate function that returns the population standard deviation calculated from values of a group. + * An aggregate function that returns the population standard deviation calculated from + * values of a group. * * @since 3.3.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java index 622d5de555cb2..d13783c02f8d3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/StddevSamp.java @@ -21,7 +21,8 @@ import org.apache.spark.sql.connector.expressions.NamedReference; /** - * An aggregate function that returns the sample standard deviation calculated from values of a group. + * An aggregate function that returns the sample standard deviation calculated from + * values of a group. 
* * @since 3.3.0 */ From 0727d73a792d2eecaa7d8a0c6fa029cf40adbac1 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 6 Dec 2021 18:57:04 +0800 Subject: [PATCH 3/8] Update code --- .../expressions/aggregate/AnyOrSome.java | 48 ------------------- .../connector/expressions/aggregate/Corr.java | 14 +++--- .../expressions/aggregate/CovarPop.java | 4 +- .../expressions/aggregate/CovarSamp.java | 4 +- .../expressions/aggregate/Every.java | 41 ---------------- .../datasources/DataSourceStrategy.scala | 9 ---- 6 files changed, 11 insertions(+), 109 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/AnyOrSome.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Every.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/AnyOrSome.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/AnyOrSome.java deleted file mode 100644 index 6186ede8e9375..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/AnyOrSome.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.connector.expressions.aggregate; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * An aggregate function that returns true if at least one value of `column` is true. - * - * @since 3.3.0 - */ -@Evolving -public final class AnyOrSome implements AggregateFunc { - public static final String ANY = "ANY"; - public static final String SOME = "SOME"; - - private final NamedReference column; - private final String realName; - - public AnyOrSome(NamedReference column, String realName) { - this.column = column; - this.realName = realName; - } - - public NamedReference column() { return column; } - - @Override - public String toString() { return realName + "(" + column.describe() + ")"; } - - @Override - public String describe() { return this.toString(); } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java index 67bc18e8b84d5..b4ba442f9a9ae 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Corr.java @@ -28,21 +28,21 @@ */ @Evolving public final class Corr implements AggregateFunc { - private final NamedReference x; - private final NamedReference y; + private final NamedReference left; + private final NamedReference right; public Corr(NamedReference left, NamedReference right) { - this.x = left; - this.y = right; + this.left = left; + this.right = right; } - public NamedReference getX() { return x; } + public NamedReference left() { return left; } - public NamedReference getY() { return y; } + public NamedReference right() { return right; } @Override public String toString() { - return "CORR(" + x.describe() + "," + y.describe() + ")"; + return "CORR(" + left.describe() + "," + 
right.describe() + ")"; } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java index fb7afd873ae50..e8d91563f7e12 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarPop.java @@ -35,9 +35,9 @@ public CovarPop(NamedReference left, NamedReference right) { this.right = right; } - public NamedReference getLeft() { return left; } + public NamedReference left() { return left; } - public NamedReference getRight() { return right; } + public NamedReference right() { return right; } @Override public String toString() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java index d861b7a581a25..dcaaf93da8b49 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CovarSamp.java @@ -35,9 +35,9 @@ public CovarSamp(NamedReference left, NamedReference right) { this.right = right; } - public NamedReference getLeft() { return left; } + public NamedReference left() { return left; } - public NamedReference getRight() { return right; } + public NamedReference right() { return right; } @Override public String toString() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Every.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Every.java deleted file mode 100644 index ecfd446ffc22f..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Every.java +++ /dev/null @@ -1,41 +0,0 @@ -/* 
- * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.aggregate; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * An aggregate function that returns true if all values of `column` are true. 
- * - * @since 3.3.0 - */ -@Evolving -public final class Every implements AggregateFunc { - private final NamedReference column; - - public Every(NamedReference column) { this.column = column; } - - public NamedReference column() { return column; } - - @Override - public String toString() { return "EVERY(" + column.describe() + ")"; } - - @Override - public String describe() { return this.toString(); } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 0b023e4dfc518..bdd7c2de5dce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -719,15 +719,6 @@ object DataSourceStrategy } case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => Some(new Sum(FieldReference(name), agg.isDistinct)) - case every @ aggregate.BoolAnd(PushableColumnWithoutNestedColumn(name)) - if every.nodeName == "every" => - Some(new Every(FieldReference(name))) - case any @ aggregate.BoolOr(PushableColumnWithoutNestedColumn(name)) - if any.nodeName == "any" => - Some(new AnyOrSome(FieldReference(name), AnyOrSome.ANY)) - case some @ aggregate.BoolOr(PushableColumnWithoutNestedColumn(name)) - if some.nodeName == "some" => - Some(new AnyOrSome(FieldReference(name), AnyOrSome.SOME)) case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) => Some(new VarPop(FieldReference(name))) case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) => From b2104de5280ea5cbf771a7cd7946b1a0ee326249 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 6 Dec 2021 18:58:35 +0800 Subject: [PATCH 4/8] Update code --- .../spark/sql/execution/datasources/DataSourceStrategy.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index bdd7c2de5dce9..9af9fb7b13022 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.FieldReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, AnyOrSome, Corr, Count, CountStar, CovarPop, CovarSamp, Every, Max, Min, StddevPop, StddevSamp, Sum, VarPop, VarSamp} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Corr, Count, CountStar, CovarPop, CovarSamp, Max, Min, StddevPop, StddevSamp, Sum, VarPop, VarSamp} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ From d92eecfcebbeb59768f3832c8e06f86ec42fee5b Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 7 Dec 2021 16:22:24 +0800 Subject: [PATCH 5/8] Update code --- .../sql/jdbc/PostgresIntegrationSuite.scala | 87 +++++++++++++++++++ .../expressions/aggregate/Average.java | 41 +++++++++ .../datasources/DataSourceStrategy.scala | 4 +- .../v2/V2ScanRelationPushDown.scala | 8 +- .../org/apache/spark/sql/jdbc/H2Dialect.scala | 21 +++++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 5 +- .../spark/sql/jdbc/PostgresDialect.scala | 45 ++++++++++ .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 54 ++++++++++++ 8 files changed, 262 insertions(+), 3 deletions(-) create mode 100644 
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Average.java diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 2562ee78ec5fc..ad16423f0624c 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -25,6 +25,8 @@ import java.util.Properties import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.connector.expressions.aggregate.{CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType} import org.apache.spark.tags.DockerTest @@ -138,6 +140,19 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { "c0 money)").executeUpdate() conn.prepareStatement("INSERT INTO money_types VALUES " + "('$1,000.00')").executeUpdate() + conn.prepareStatement( + "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + + " bonus DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)") 
+ .executeUpdate() } test("Type mapping for various types") { @@ -379,4 +394,76 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(row(0).length === 1) assert(row(0).getString(0) === "$1,000.00") } + + test("scan with aggregate push-down: VAR_POP VAR_SAMP") { + val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM foo.test.employee where dept > 0" + + " group by DePt") + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, output) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions(0).isInstanceOf[VarPop]) + assert(aggregationExpressions(0).isInstanceOf[VarSamp]) + } + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 2) + assert(row(0).getDouble(0) === 10000d) + assert(row(0).getDouble(1) === 20000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(1).getDouble(1) === 5000d) + assert(row(1).getDouble(0) === 0d) + assert(row(1).getDouble(1) === null) + } + + test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP") { + val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" + + " where dept > 0 group by DePt") + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, output) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions(0).isInstanceOf[StddevPop]) + assert(aggregationExpressions(0).isInstanceOf[StddevSamp]) + } + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 2) + assert(row(0).getDouble(0) === 100d) + 
assert(row(0).getDouble(1) === 141.4213562373095d) + assert(row(1).getDouble(0) === 50d) + assert(row(1).getDouble(1) === 70.71067811865476d) + assert(row(1).getDouble(0) === 0d) + assert(row(1).getDouble(1) === null) + } + + test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") { + val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus) FROM h2.test.employee" + + " where dept > 0 group by DePt") + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, output) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions(0).isInstanceOf[CovarPop]) + assert(aggregationExpressions(0).isInstanceOf[CovarSamp]) + } + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 2) + assert(row(0).getDouble(0) === 10000d) + assert(row(0).getDouble(1) === 20000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(1).getDouble(1) === 5000d) + assert(row(1).getDouble(0) === 0d) + assert(row(1).getDouble(1) === null) + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Average.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Average.java new file mode 100644 index 0000000000000..05c0b4c0635d8 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Average.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns the mean calculated from values of a group. + * + * @since 3.3.0 + */ +@Evolving +public final class Average implements AggregateFunc { + private final NamedReference column; + + public Average(NamedReference column) { this.column = column; } + + public NamedReference column() { return column; } + + @Override + public String toString() { return "AVG(" + column.describe() + ")"; } + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 9af9fb7b13022..57da9b25b93d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.FieldReference -import 
org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Corr, Count, CountStar, CovarPop, CovarSamp, Max, Min, StddevPop, StddevSamp, Sum, VarPop, VarSamp} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Average, Corr, Count, CountStar, CovarPop, CovarSamp, Max, Min, StddevPop, StddevSamp, Sum, VarPop, VarSamp} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -719,6 +719,8 @@ object DataSourceStrategy } case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => Some(new Sum(FieldReference(name), agg.isDistinct)) + case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) => + Some(new Average(FieldReference(name))) case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) => Some(new VarPop(FieldReference(name))) case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 36923d1c0bd64..a9d907567fdfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import 
org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation @@ -176,7 +176,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case max: aggregate.Max => max.copy(child = aggOutput(ordinal)) case min: aggregate.Min => min.copy(child = aggOutput(ordinal)) case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal)) + case avg: aggregate.Average => + aggregate.First(Cast(aggOutput(ordinal), avg.dataType), true) case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal)) + case _: aggregate.VariancePop | _: aggregate.VarianceSamp | + _: aggregate.StddevPop | _: aggregate.StddevSamp | + _: aggregate.CovPopulation | _: aggregate.CovSample => + aggregate.First(aggOutput(ordinal), true) case other => other } agg.copy(aggregateFunction = aggFunction) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 9c727957ffab8..3a99402f2f9e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -22,11 +22,32 @@ import java.util.Locale import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, StddevPop, StddevSamp, VarPop, VarSamp} private object H2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case varPop: VarPop => + if (varPop.column.fieldNames.length != 1) return None + 
Some(s"VAR_POP(${quoteIdentifier(varPop.column.fieldNames.head)})") + case varSamp: VarSamp => + if (varSamp.column.fieldNames.length != 1) return None + Some(s"VAR_SAMP(${quoteIdentifier(varSamp.column.fieldNames.head)})") + case stddevPop: StddevPop => + if (stddevPop.column.fieldNames.length != 1) return None + Some(s"STDDEV_POP(${quoteIdentifier(stddevPop.column.fieldNames.head)})") + case stddevSamp: StddevSamp => + if (stddevSamp.column.fieldNames.length != 1) return None + Some(s"STDDEV_SAMP(${quoteIdentifier(stddevSamp.column.fieldNames.head)})") + case _ => None + } + ) + } + override def classifyException(message: String, e: Throwable): AnalysisException = { if (e.isInstanceOf[SQLException]) { // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 9a647e545d836..97c39be4c4a05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Average, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -208,6 +208,9 @@ abstract class JdbcDialect extends Serializable with Logging{ case max: Max => if (max.column.fieldNames.length != 1) return None 
Some(s"MAX(${quoteIdentifier(max.column.fieldNames.head)})") + case avg: Average => + if (avg.column.fieldNames.length != 1) return None + Some(s"AVG(${quoteIdentifier(avg.column.fieldNames.head)})") case count: Count => if (count.column.fieldNames.length != 1) return None val distinct = if (count.isDistinct) "DISTINCT " else "" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 317ae19ed914b..9faf73ae47447 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Types} import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ @@ -139,6 +140,50 @@ private object PostgresDialect extends JdbcDialect { } } + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case varPop: VarPop => + if (varPop.column.fieldNames.length != 1) return None + Some(s"VAR_POP(${quoteIdentifier(varPop.column.fieldNames.head)})") + case varSamp: VarSamp => + if (varSamp.column.fieldNames.length != 1) return None + Some(s"VAR_SAMP(${quoteIdentifier(varSamp.column.fieldNames.head)})") + case stddevPop: StddevPop => + if (stddevPop.column.fieldNames.length != 1) return None + Some(s"STDDEV_POP(${quoteIdentifier(stddevPop.column.fieldNames.head)})") + case stddevSamp: StddevSamp => + if (stddevSamp.column.fieldNames.length != 1) return None + Some(s"STDDEV_SAMP(${quoteIdentifier(stddevSamp.column.fieldNames.head)})") + case 
covarPop: CovarPop => + if (covarPop.left.fieldNames.length != 1 && + covarPop.right.fieldNames.length != 1) { + return None + } + val compiledValue = + s""" + |COVAR_POP( + |${quoteIdentifier(covarPop.left.fieldNames.head)}, + |${quoteIdentifier(covarPop.right.fieldNames.head)}) + |""".stripMargin.replaceAll("\n", "") + Some(compiledValue) + case covarSamp: CovarSamp => + if (covarSamp.left.fieldNames.length != 1 && + covarSamp.right.fieldNames.length != 1) { + return None + } + val compiledValue = + s""" + |COVAR_SAMP( + |${quoteIdentifier(covarSamp.left.fieldNames.head)}, + |${quoteIdentifier(covarSamp.right.fieldNames.head)}) + |""".stripMargin.replaceAll("\n", "") + Some(compiledValue) + case _ => None + } + ) + } + // See https://www.postgresql.org/docs/12/sql-altertable.html override def getUpdateColumnTypeQuery( tableName: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index b87b4f6d86fd1..0765e9573db66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -316,6 +316,60 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } + test("scan with aggregate push-down: AVG with filter and group by") { + val df = sql("select AVG(SaLaRY) FROM h2.test.employee where dept > 0" + + " group by DePt") + val filters = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters.isEmpty) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [AVG(SALARY)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(9500), Row(11000), Row(12000))) + } + + test("scan with aggregate push-down: VAR_POP 
VAR_SAMP with filter and group by") { + val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" + + " group by DePt") + val filters = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters.isEmpty) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + } + + test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") { + val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" + + " where dept > 0 group by DePt") + val filters = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters.isEmpty) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) + } + test("scan with aggregate push-down: MAX MIN with filter and group by") { val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DePt") From 651f40be11669ce3c68946da00bffce4715f494d Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 7 Dec 2021 17:35:56 +0800 Subject: [PATCH 6/8] Update code --- .../sql/jdbc/PostgresIntegrationSuite.scala | 30 +++++++++++++++---- .../v2/V2ScanRelationPushDown.scala | 3 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 14 +++++++++ .../spark/sql/jdbc/PostgresDialect.scala | 28 
++++------------- 4 files changed, 46 insertions(+), 29 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index ad16423f0624c..baa79104c100a 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -25,7 +25,7 @@ import java.util.Properties import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.connector.expressions.aggregate.{CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp} +import org.apache.spark.sql.connector.expressions.aggregate.{Corr, CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType} import org.apache.spark.tags.DockerTest @@ -142,7 +142,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { "('$1,000.00')").executeUpdate() conn.prepareStatement( "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + - " bonus DOUBLE)").executeUpdate() + " bonus double precision)").executeUpdate() conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)") .executeUpdate() conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)") @@ -396,7 +396,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { } test("scan with aggregate push-down: VAR_POP VAR_SAMP") { - val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM foo.test.employee where dept > 0" + + val df = sql("select VAR_POP(bonus), 
VAR_SAMP(bonus) FROM test.employee where dept > 0" + " group by DePt") df.queryExecution.optimizedPlan.collect { case DataSourceV2ScanRelation(_, scan, output) => @@ -420,7 +420,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { } test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP") { - val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" + + val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM test.employee" + " where dept > 0 group by DePt") df.queryExecution.optimizedPlan.collect { case DataSourceV2ScanRelation(_, scan, output) => @@ -444,7 +444,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { } test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") { - val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus) FROM h2.test.employee" + + val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus) FROM test.employee" + " where dept > 0 group by DePt") df.queryExecution.optimizedPlan.collect { case DataSourceV2ScanRelation(_, scan, output) => @@ -466,4 +466,24 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(row(1).getDouble(0) === 0d) assert(row(1).getDouble(1) === null) } + + test("scan with aggregate push-down: CORR with filter and group by") { + val df = sql("select CORR(bonus, bonus) FROM test.employee where dept > 0" + + " group by DePt") + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, output) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions(0).isInstanceOf[Corr]) + } + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 1) + assert(row(0).getDouble(0) === 1d) + assert(row(1).getDouble(0) 
=== 1d) + assert(row(1).getDouble(0) === null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index a9d907567fdfb..4ce47a15b24f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -181,7 +181,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal)) case _: aggregate.VariancePop | _: aggregate.VarianceSamp | _: aggregate.StddevPop | _: aggregate.StddevSamp | - _: aggregate.CovPopulation | _: aggregate.CovSample => + _: aggregate.CovPopulation | _: aggregate.CovSample | + _: aggregate.Corr => aggregate.First(aggOutput(ordinal), true) case other => other } agg.copy(aggregateFunction = aggFunction) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 97c39be4c4a05..100f7285ebae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -227,6 +227,20 @@ abstract class JdbcDialect extends Serializable with Logging{ } } + protected def compileLeftRight( + left: NamedReference, right: NamedReference, funcName: String): Option[String] = { + if (left.fieldNames.length != 1 || right.fieldNames.length != 1) { + return None + } + val compiledValue = + s""" + |$funcName( + |${quoteIdentifier(left.fieldNames.head)}, + |${quoteIdentifier(right.fieldNames.head)}) + |""".stripMargin.replaceAll("\n", "") + Some(compiledValue) + } + /** * Return Some[true] iff `TRUNCATE TABLE` causes cascading default. * Some[true] : TRUNCATE TABLE causes cascading. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 9faf73ae47447..3ede234569353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Types} import java.util.Locale -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Corr, CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ @@ -156,29 +156,11 @@ private object PostgresDialect extends JdbcDialect { if (stddevSamp.column.fieldNames.length != 1) return None Some(s"STDDEV_SAMP(${quoteIdentifier(stddevSamp.column.fieldNames.head)})") case covarPop: CovarPop => - if (covarPop.left.fieldNames.length != 1 && - covarPop.right.fieldNames.length != 1) { - return None - } - val compiledValue = - s""" - |COVAR_POP( - |${quoteIdentifier(covarPop.left.fieldNames.head)}, - |${quoteIdentifier(covarPop.right.fieldNames.head)}) - |""".stripMargin.replaceAll("\n", "") - Some(compiledValue) + compileLeftRight(covarPop.left, covarPop.right, "COVAR_POP") case covarSamp: CovarSamp => - if (covarSamp.left.fieldNames.length != 1 && - covarSamp.right.fieldNames.length != 1) { - return None - } - val compiledValue = - s""" - |COVAR_SAMP( - |${quoteIdentifier(covarSamp.left.fieldNames.head)}, - |${quoteIdentifier(covarSamp.right.fieldNames.head)}) - |""".stripMargin.replaceAll("\n", "") - Some(compiledValue) + compileLeftRight(covarSamp.left, covarSamp.right, "COVAR_SAMP") + case corr: Corr => + 
compileLeftRight(corr.left, corr.right, "CORR") case _ => None } ) From 2adba743c1ecf945d366fd2f14166e24e3df5bfc Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 7 Dec 2021 20:06:21 +0800 Subject: [PATCH 7/8] Update code --- .../sql/jdbc/PostgresIntegrationSuite.scala | 107 ----------------- .../jdbc/v2/PostgresIntegrationSuite.scala | 112 +++++++++++++++++- 2 files changed, 111 insertions(+), 108 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index baa79104c100a..2562ee78ec5fc 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -25,8 +25,6 @@ import java.util.Properties import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.connector.expressions.aggregate.{Corr, CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType} import org.apache.spark.tags.DockerTest @@ -140,19 +138,6 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { "c0 money)").executeUpdate() conn.prepareStatement("INSERT INTO money_types VALUES " + "('$1,000.00')").executeUpdate() - conn.prepareStatement( - "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + - " bonus double precision)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 
1200)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)") - .executeUpdate() } test("Type mapping for various types") { @@ -394,96 +379,4 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(row(0).length === 1) assert(row(0).getString(0) === "$1,000.00") } - - test("scan with aggregate push-down: VAR_POP VAR_SAMP") { - val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM test.employee where dept > 0" + - " group by DePt") - df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, output) => - assert(scan.isInstanceOf[V1ScanWrapper]) - val wrapper = scan.asInstanceOf[V1ScanWrapper] - assert(wrapper.pushedDownOperators.aggregation.isDefined) - val aggregationExpressions = - wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() - assert(aggregationExpressions(0).isInstanceOf[VarPop]) - assert(aggregationExpressions(0).isInstanceOf[VarSamp]) - } - val row = df.collect() - assert(row.length === 3) - assert(row(0).length === 2) - assert(row(0).getDouble(0) === 10000d) - assert(row(0).getDouble(1) === 20000d) - assert(row(1).getDouble(0) === 2500d) - assert(row(1).getDouble(1) === 5000d) - assert(row(1).getDouble(0) === 0d) - assert(row(1).getDouble(1) === null) - } - - test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP") { - val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM test.employee" + - " where dept > 0 group by DePt") - df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, output) => - assert(scan.isInstanceOf[V1ScanWrapper]) - val wrapper = scan.asInstanceOf[V1ScanWrapper] - assert(wrapper.pushedDownOperators.aggregation.isDefined) - val aggregationExpressions = 
- wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() - assert(aggregationExpressions(0).isInstanceOf[StddevPop]) - assert(aggregationExpressions(0).isInstanceOf[StddevSamp]) - } - val row = df.collect() - assert(row.length === 3) - assert(row(0).length === 2) - assert(row(0).getDouble(0) === 100d) - assert(row(0).getDouble(1) === 141.4213562373095d) - assert(row(1).getDouble(0) === 50d) - assert(row(1).getDouble(1) === 70.71067811865476d) - assert(row(1).getDouble(0) === 0d) - assert(row(1).getDouble(1) === null) - } - - test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") { - val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus) FROM test.employee" + - " where dept > 0 group by DePt") - df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, output) => - assert(scan.isInstanceOf[V1ScanWrapper]) - val wrapper = scan.asInstanceOf[V1ScanWrapper] - assert(wrapper.pushedDownOperators.aggregation.isDefined) - val aggregationExpressions = - wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() - assert(aggregationExpressions(0).isInstanceOf[CovarPop]) - assert(aggregationExpressions(0).isInstanceOf[CovarSamp]) - } - val row = df.collect() - assert(row.length === 3) - assert(row(0).length === 2) - assert(row(0).getDouble(0) === 10000d) - assert(row(0).getDouble(1) === 20000d) - assert(row(1).getDouble(0) === 2500d) - assert(row(1).getDouble(1) === 5000d) - assert(row(1).getDouble(0) === 0d) - assert(row(1).getDouble(1) === null) - } - - test("scan with aggregate push-down: CORR with filter and group by") { - val df = sql("select CORR(bonus, bonus) FROM test.employee where dept > 0" + - " group by DePt") - df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, output) => - assert(scan.isInstanceOf[V1ScanWrapper]) - val wrapper = scan.asInstanceOf[V1ScanWrapper] - assert(wrapper.pushedDownOperators.aggregation.isDefined) - val 
aggregationExpressions = - wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() - assert(aggregationExpressions(0).isInstanceOf[Corr]) - } - val row = df.collect() - assert(row.length === 3) - assert(row(0).length === 1) - assert(row(0).getDouble(0) === 1d) - assert(row(1).getDouble(0) === 1d) - assert(row(1).getDouble(0) === null) - } } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 1a1a592d00bca..bb5df196065b8 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -21,6 +21,8 @@ import java.sql.Connection import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.expressions.aggregate.{Corr, CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} import org.apache.spark.sql.types._ @@ -50,9 +52,25 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes .set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort)) .set("spark.sql.catalog.postgresql.pushDownTableSample", "true") + .set("spark.sql.catalog.postgresql.pushDownAggregate", "true") .set("spark.sql.catalog.postgresql.pushDownLimit", "true") - override def dataPreparation(conn: Connection): Unit = {} + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE SCHEMA 
\"test\"").executeUpdate() + conn.prepareStatement( + "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," + + " bonus double precision)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)") + .executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -80,4 +98,96 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes } override def supportsTableSample: Boolean = true + + test("scan with aggregate push-down: VAR_POP VAR_SAMP") { + val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM postgresql.test.employee" + + " where dept > 0 group by DePt order by dept") + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, output) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions(0).isInstanceOf[VarPop]) + assert(aggregationExpressions(1).isInstanceOf[VarSamp]) + } + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 2) + assert(row(0).getDouble(0) === 10000d) + assert(row(0).getDouble(1) === 20000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(1).getDouble(1) === 5000d) + assert(row(2).getDouble(0) === 0d) + assert(row(2).isNullAt(1)) + } + + test("scan 
with aggregate push-down: STDDEV_POP STDDEV_SAMP") { + val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM postgresql.test.employee" + + " where dept > 0 group by DePt order by dept") + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, output) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions(0).isInstanceOf[StddevPop]) + assert(aggregationExpressions(1).isInstanceOf[StddevSamp]) + } + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 2) + assert(row(0).getDouble(0) === 100d) + assert(row(0).getDouble(1) === 141.4213562373095d) + assert(row(1).getDouble(0) === 50d) + assert(row(1).getDouble(1) === 70.71067811865476d) + assert(row(2).getDouble(0) === 0d) + assert(row(2).isNullAt(1)) + } + + test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") { + val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus) FROM" + + " postgresql.test.employee where dept > 0 group by DePt order by dept") + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, output) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions(0).isInstanceOf[CovarPop]) + assert(aggregationExpressions(1).isInstanceOf[CovarSamp]) + } + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 2) + assert(row(0).getDouble(0) === 10000d) + assert(row(0).getDouble(1) === 20000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(1).getDouble(1) === 5000d) + assert(row(2).getDouble(0) === 
0d) + assert(row(2).isNullAt(1)) + } + + test("scan with aggregate push-down: CORR with filter and group by") { + val df = sql("select CORR(bonus, bonus) FROM postgresql.test.employee where dept > 0" + + " group by DePt order by dept") + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, output) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions(0).isInstanceOf[Corr]) + } + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 1) + assert(row(0).getDouble(0) === 1d) + assert(row(1).getDouble(0) === 1d) + assert(row(2).isNullAt(0)) + } } From b9c7d9645c6e30752bc786fbf1519ce0fd44df0d Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 7 Dec 2021 20:11:53 +0800 Subject: [PATCH 8/8] Update code --- .../spark/sql/jdbc/v2/PostgresIntegrationSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index bb5df196065b8..79ef4298eaa99 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -103,7 +103,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM postgresql.test.employee" + " where dept > 0 group by DePt order by dept") df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, output) => + case DataSourceV2ScanRelation(_, scan, _) => 
assert(scan.isInstanceOf[V1ScanWrapper]) val wrapper = scan.asInstanceOf[V1ScanWrapper] assert(wrapper.pushedDownOperators.aggregation.isDefined) @@ -127,7 +127,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM postgresql.test.employee" + " where dept > 0 group by DePt order by dept") df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, output) => + case DataSourceV2ScanRelation(_, scan, _) => assert(scan.isInstanceOf[V1ScanWrapper]) val wrapper = scan.asInstanceOf[V1ScanWrapper] assert(wrapper.pushedDownOperators.aggregation.isDefined) @@ -151,7 +151,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus) FROM" + " postgresql.test.employee where dept > 0 group by DePt order by dept") df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, output) => + case DataSourceV2ScanRelation(_, scan, _) => assert(scan.isInstanceOf[V1ScanWrapper]) val wrapper = scan.asInstanceOf[V1ScanWrapper] assert(wrapper.pushedDownOperators.aggregation.isDefined) @@ -175,7 +175,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes val df = sql("select CORR(bonus, bonus) FROM postgresql.test.employee where dept > 0" + " group by DePt order by dept") df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, output) => + case DataSourceV2ScanRelation(_, scan, _) => assert(scan.isInstanceOf[V1ScanWrapper]) val wrapper = scan.asInstanceOf[V1ScanWrapper] assert(wrapper.pushedDownOperators.aggregation.isDefined)