From 73b5da34b4fbd4d192f0b9a9f6a23ca805b21207 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 11 Sep 2021 10:12:21 -0700 Subject: [PATCH 01/53] [SPARK-36556][SQL] Add DSV2 filters Co-Authored-By: DB Tsai d_tsaiapple.com Co-Authored-By: Huaxin Gao huaxin_gaoapple.com ### What changes were proposed in this pull request? Add DSV2 filters and use them in the V2 code path. ### Why are the changes needed? The motivation for adding DSV2 filters: 1. The values in V1 filters are Scala types. When translating a Catalyst `Expression` to V1 filters, we have to call `convertToScala` to convert from the Catalyst types used internally in rows to standard Scala types, and later convert the Scala types back to Catalyst types. This is very inefficient. In V2 filters, we use `Expression` for filter values, so the conversions from Catalyst types to Scala types and back are avoided. 2. Improve nested column filter support. 3. Make the filters work better with the rest of the DSV2 APIs. ### Does this PR introduce _any_ user-facing change? Yes. The new V2 filters are added. ### How was this patch tested? New tests. Closes #33803 from huaxingao/filter. Lead-authored-by: Huaxin Gao Co-authored-by: DB Tsai Signed-off-by: Liang-Chi Hsieh --- .../expressions/filter/AlwaysFalse.java | 50 +++++ .../expressions/filter/AlwaysTrue.java | 50 +++++ .../sql/connector/expressions/filter/And.java | 39 ++++ .../expressions/filter/BinaryComparison.java | 60 ++++++ .../expressions/filter/BinaryFilter.java | 65 ++++++ .../expressions/filter/EqualNullSafe.java | 40 ++++ .../connector/expressions/filter/EqualTo.java | 39 ++++ .../connector/expressions/filter/Filter.java | 41 ++++ .../expressions/filter/GreaterThan.java | 39 ++++ .../filter/GreaterThanOrEqual.java | 39 ++++ .../sql/connector/expressions/filter/In.java | 76 +++++++ .../expressions/filter/IsNotNull.java | 58 +++++ .../connector/expressions/filter/IsNull.java | 58 +++++ .../expressions/filter/LessThan.java | 39 ++++ .../expressions/filter/LessThanOrEqual.java | 39 ++++ .../sql/connector/expressions/filter/Not.java | 56 +++++ .../sql/connector/expressions/filter/Or.java | 39 ++++ .../expressions/filter/StringContains.java | 39 ++++ .../expressions/filter/StringEndsWith.java | 39 ++++ .../expressions/filter/StringPredicate.java | 60 ++++++ .../expressions/filter/StringStartsWith.java | 41 ++++ .../datasources/v2/V2FiltersSuite.scala | 204 ++++++++++++++++++ 22 files changed, 1210 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryFilter.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java create mode 100644 
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2FiltersSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java new file mode 100644 index 0000000000000..72ed83f86df6d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that always evaluates to {@code false}. 
+ * + * @since 3.3.0 + */ +@Evolving +public final class AlwaysFalse extends Filter { + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hash(); + } + + @Override + public String toString() { return "FALSE"; } + + @Override + public NamedReference[] references() { return EMPTY_REFERENCE; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java new file mode 100644 index 0000000000000..b6d39c3f64a77 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that always evaluates to {@code true}. + * + * @since 3.3.0 + */ +@Evolving +public final class AlwaysTrue extends Filter { + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hash(); + } + + @Override + public String toString() { return "TRUE"; } + + @Override + public NamedReference[] references() { return EMPTY_REFERENCE; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java new file mode 100644 index 0000000000000..e0b8b13acb158 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; + +/** + * A filter that evaluates to {@code true} iff both {@code left} and {@code right} evaluate to + * {@code true}. + * + * @since 3.3.0 + */ +@Evolving +public final class And extends BinaryFilter { + + public And(Filter left, Filter right) { + super(left, right); + } + + @Override + public String toString() { + return String.format("(%s) AND (%s)", left.describe(), right.describe()); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java new file mode 100644 index 0000000000000..0ae6e5af3ca1a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * Base class for {@link EqualNullSafe}, {@link EqualTo}, {@link GreaterThan}, + * {@link GreaterThanOrEqual}, {@link LessThan}, {@link LessThanOrEqual} + * + * @since 3.3.0 + */ +@Evolving +abstract class BinaryComparison extends Filter { + protected final NamedReference column; + protected final Literal value; + + protected BinaryComparison(NamedReference column, Literal value) { + this.column = column; + this.value = value; + } + + public NamedReference column() { return column; } + public Literal value() { return value; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BinaryComparison that = (BinaryComparison) o; + return Objects.equals(column, that.column) && Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(column, value); + } + + @Override + public NamedReference[] references() { return new NamedReference[] { column }; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryFilter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryFilter.java new file mode 100644 index 0000000000000..ac4b9f281e9ca --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryFilter.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * Base class for {@link And}, {@link Or} + * + * @since 3.3.0 + */ +@Evolving +abstract class BinaryFilter extends Filter { + protected final Filter left; + protected final Filter right; + + protected BinaryFilter(Filter left, Filter right) { + this.left = left; + this.right = right; + } + + public Filter left() { return left; } + public Filter right() { return right; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BinaryFilter and = (BinaryFilter) o; + return Objects.equals(left, and.left) && Objects.equals(right, and.right); + } + + @Override + public int hashCode() { + return Objects.hash(left, right); + } + + @Override + public NamedReference[] references() { + NamedReference[] refLeft = left.references(); + NamedReference[] refRight = right.references(); + NamedReference[] arr = new NamedReference[refLeft.length + refRight.length]; + System.arraycopy(refLeft, 0, arr, 0, refLeft.length); + System.arraycopy(refRight, 0, arr, refLeft.length, refRight.length); + return arr; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java new file mode 100644 index 0000000000000..34b529194e075 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * Performs equality comparison, similar to {@link EqualTo}. 
However, this differs from + * {@link EqualTo} in that it returns {@code true} (rather than NULL) if both inputs are NULL, + * and {@code false} (rather than NULL) if one of the input is NULL and the other is not NULL. + * + * @since 3.3.0 + */ +@Evolving +public final class EqualNullSafe extends BinaryComparison { + + public EqualNullSafe(NamedReference column, Literal value) { + super(column, value); + } + + @Override + public String toString() { return this.column.describe() + " <=> " + value.describe(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java new file mode 100644 index 0000000000000..b9c4fe053b83c --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value + * equal to {@code value}. + * + * @since 3.3.0 + */ +@Evolving +public final class EqualTo extends BinaryComparison { + + public EqualTo(NamedReference column, Literal value) { + super(column, value); + } + + @Override + public String toString() { return column.describe() + " = " + value.describe(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java new file mode 100644 index 0000000000000..852837496a103 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * Filter base class + * + * @since 3.3.0 + */ +@Evolving +public abstract class Filter implements Expression { + + protected static final NamedReference[] EMPTY_REFERENCE = new NamedReference[0]; + + /** + * Returns list of columns that are referenced by this filter. + */ + public abstract NamedReference[] references(); + + @Override + public String describe() { return this.toString(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java new file mode 100644 index 0000000000000..a3374f359ea29 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value + * greater than {@code value}. + * + * @since 3.3.0 + */ +@Evolving +public final class GreaterThan extends BinaryComparison { + + public GreaterThan(NamedReference column, Literal value) { + super(column, value); + } + + @Override + public String toString() { return column.describe() + " > " + value.describe(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java new file mode 100644 index 0000000000000..4ee921014da41 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value + * greater than or equal to {@code value}. + * + * @since 3.3.0 + */ +@Evolving +public final class GreaterThanOrEqual extends BinaryComparison { + + public GreaterThanOrEqual(NamedReference column, Literal value) { + super(column, value); + } + + @Override + public String toString() { return column.describe() + " >= " + value.describe(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java new file mode 100644 index 0000000000000..8d6490b8984fd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import java.util.Arrays; +import java.util.Objects; +import java.util.stream.Collectors; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to one of the + * {@code values} in the array. 
+ * + * @since 3.3.0 + */ +@Evolving +public final class In extends Filter { + static final int MAX_LEN_TO_PRINT = 50; + private final NamedReference column; + private final Literal[] values; + + public In(NamedReference column, Literal[] values) { + this.column = column; + this.values = values; + } + + public NamedReference column() { return column; } + public Literal[] values() { return values; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + In in = (In) o; + return Objects.equals(column, in.column) && values.length == in.values.length + && Arrays.asList(values).containsAll(Arrays.asList(in.values)); + } + + @Override + public int hashCode() { + int result = Objects.hash(column); + result = 31 * result + Arrays.hashCode(values); + return result; + } + + @Override + public String toString() { + String res = Arrays.stream(values).limit((MAX_LEN_TO_PRINT)).map(Literal::describe) + .collect(Collectors.joining(", ")); + if (values.length > MAX_LEN_TO_PRINT) { + res += "..."; + } + return column.describe() + " IN (" + res + ")"; + } + + @Override + public NamedReference[] references() { return new NamedReference[] { column }; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java new file mode 100644 index 0000000000000..2cf000e99878e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to a non-null value. 
+ * + * @since 3.3.0 + */ +@Evolving +public final class IsNotNull extends Filter { + private final NamedReference column; + + public IsNotNull(NamedReference column) { + this.column = column; + } + + public NamedReference column() { return column; } + + @Override + public String toString() { return column.describe() + " IS NOT NULL"; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IsNotNull isNotNull = (IsNotNull) o; + return Objects.equals(column, isNotNull.column); + } + + @Override + public int hashCode() { + return Objects.hash(column); + } + + @Override + public NamedReference[] references() { return new NamedReference[] { column }; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java new file mode 100644 index 0000000000000..1cd497c02242e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to null. + * + * @since 3.3.0 + */ +@Evolving +public final class IsNull extends Filter { + private final NamedReference column; + + public IsNull(NamedReference column) { + this.column = column; + } + + public NamedReference column() { return column; } + + @Override + public String toString() { return column.describe() + " IS NULL"; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IsNull isNull = (IsNull) o; + return Objects.equals(column, isNull.column); + } + + @Override + public int hashCode() { + return Objects.hash(column); + } + + @Override + public NamedReference[] references() { return new NamedReference[] { column }; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java new file mode 100644 index 0000000000000..9fa5cfb87f527 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value + * less than {@code value}. + * + * @since 3.3.0 + */ +@Evolving +public final class LessThan extends BinaryComparison { + + public LessThan(NamedReference column, Literal value) { + super(column, value); + } + + @Override + public String toString() { return column.describe() + " < " + value.describe(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java new file mode 100644 index 0000000000000..a41b3c8045d5a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value + * less than or equal to {@code value}. 
+ * + * @since 3.3.0 + */ +@Evolving +public final class LessThanOrEqual extends BinaryComparison { + + public LessThanOrEqual(NamedReference column, Literal value) { + super(column, value); + } + + @Override + public String toString() { return column.describe() + " <= " + value.describe(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java new file mode 100644 index 0000000000000..69746f59ee933 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * A filter that evaluates to {@code true} iff {@code child} is evaluated to {@code false}. + * + * @since 3.3.0 + */ +@Evolving +public final class Not extends Filter { + private final Filter child; + + public Not(Filter child) { this.child = child; } + + public Filter child() { return child; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Not not = (Not) o; + return Objects.equals(child, not.child); + } + + @Override + public int hashCode() { + return Objects.hash(child); + } + + @Override + public String toString() { return "NOT (" + child.describe() + ")"; } + + @Override + public NamedReference[] references() { return child.references(); } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java new file mode 100644 index 0000000000000..baa33d849feef --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; + +/** + * A filter that evaluates to {@code true} iff at least one of {@code left} or {@code right} + * evaluates to {@code true}. + * + * @since 3.3.0 + */ +@Evolving +public final class Or extends BinaryFilter { + + public Or(Filter left, Filter right) { + super(left, right); + } + + @Override + public String toString() { + return String.format("(%s) OR (%s)", left.describe(), right.describe()); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java new file mode 100644 index 0000000000000..9a01e4d574888 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to + * a string that contains {@code value}. + * + * @since 3.3.0 + */ +@Evolving +public final class StringContains extends StringPredicate { + + public StringContains(NamedReference column, UTF8String value) { + super(column, value); + } + + @Override + public String toString() { return "STRING_CONTAINS(" + column.describe() + ", " + value + ")"; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java new file mode 100644 index 0000000000000..11b8317ba4895 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to + * a string that ends with {@code value}. + * + * @since 3.3.0 + */ +@Evolving +public final class StringEndsWith extends StringPredicate { + + public StringEndsWith(NamedReference column, UTF8String value) { + super(column, value); + } + + @Override + public String toString() { return "STRING_ENDS_WITH(" + column.describe() + ", " + value + ")"; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java new file mode 100644 index 0000000000000..ffe5d5dba45b3 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.expressions.filter; + +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Base class for {@link StringContains}, {@link StringStartsWith}, + * {@link StringEndsWith} + * + * @since 3.3.0 + */ +@Evolving +abstract class StringPredicate extends Filter { + protected final NamedReference column; + protected final UTF8String value; + + protected StringPredicate(NamedReference column, UTF8String value) { + this.column = column; + this.value = value; + } + + public NamedReference column() { return column; } + public UTF8String value() { return value; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + StringPredicate that = (StringPredicate) o; + return Objects.equals(column, that.column) && Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(column, value); + } + + @Override + public NamedReference[] references() { return new NamedReference[] { column }; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java new file mode 100644 index 0000000000000..38a5de1921cdc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A filter that evaluates to {@code true} iff the {@code column} evaluates to + * a string that starts with {@code value}. 
+ * + * @since 3.3.0 + */ +@Evolving +public final class StringStartsWith extends StringPredicate { + + public StringStartsWith(NamedReference column, UTF8String value) { + super(column, value); + } + + @Override + public String toString() { + return "STRING_STARTS_WITH(" + column.describe() + ", " + value + ")"; + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2FiltersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2FiltersSuite.scala new file mode 100644 index 0000000000000..b457211b7f89f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2FiltersSuite.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.expressions.{FieldReference, Literal, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter._ +import org.apache.spark.sql.execution.datasources.v2.FiltersV2Suite.ref +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.unsafe.types.UTF8String + +class FiltersV2Suite extends SparkFunSuite { + + test("nested columns") { + val filter1 = new EqualTo(ref("a", "B"), LiteralValue(1, IntegerType)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a.B")) + assert(filter1.describe.equals("a.B = 1")) + + val filter2 = new EqualTo(ref("a", "b.c"), LiteralValue(1, IntegerType)) + assert(filter2.references.map(_.describe()).toSeq == Seq("a.`b.c`")) + assert(filter2.describe.equals("a.`b.c` = 1")) + + val filter3 = new EqualTo(ref("`a`.b", "c"), LiteralValue(1, IntegerType)) + assert(filter3.references.map(_.describe()).toSeq == Seq("```a``.b`.c")) + assert(filter3.describe.equals("```a``.b`.c = 1")) + } + + test("AlwaysTrue") { + val filter1 = new AlwaysTrue + val filter2 = new AlwaysTrue + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).length == 0) + assert(filter1.describe.equals("TRUE")) + } + + test("AlwaysFalse") { + val filter1 = new AlwaysFalse + val filter2 = new AlwaysFalse + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).length == 0) + assert(filter1.describe.equals("FALSE")) + } + + test("EqualTo") { + val filter1 = new EqualTo(ref("a"), LiteralValue(1, IntegerType)) + val filter2 = new EqualTo(ref("a"), LiteralValue(1, IntegerType)) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("a = 1")) + } + + test("EqualNullSafe") { + val filter1 = new EqualNullSafe(ref("a"), LiteralValue(1, IntegerType)) + val filter2 = new EqualNullSafe(ref("a"), LiteralValue(1, IntegerType)) + 
assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("a <=> 1")) + } + + test("GreaterThan") { + val filter1 = new GreaterThan(ref("a"), LiteralValue(1, IntegerType)) + val filter2 = new GreaterThan(ref("a"), LiteralValue(1, IntegerType)) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("a > 1")) + } + + test("GreaterThanOrEqual") { + val filter1 = new GreaterThanOrEqual(ref("a"), LiteralValue(1, IntegerType)) + val filter2 = new GreaterThanOrEqual(ref("a"), LiteralValue(1, IntegerType)) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("a >= 1")) + } + + test("LessThan") { + val filter1 = new LessThan(ref("a"), LiteralValue(1, IntegerType)) + val filter2 = new LessThan(ref("a"), LiteralValue(1, IntegerType)) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("a < 1")) + } + + test("LessThanOrEqual") { + val filter1 = new LessThanOrEqual(ref("a"), LiteralValue(1, IntegerType)) + val filter2 = new LessThanOrEqual(ref("a"), LiteralValue(1, IntegerType)) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("a <= 1")) + } + + test("In") { + val filter1 = new In(ref("a"), + Array(LiteralValue(1, IntegerType), LiteralValue(2, IntegerType), + LiteralValue(3, IntegerType), LiteralValue(4, IntegerType))) + val filter2 = new In(ref("a"), + Array(LiteralValue(4, IntegerType), LiteralValue(2, IntegerType), + LiteralValue(3, IntegerType), LiteralValue(1, IntegerType))) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("a IN (1, 2, 3, 4)")) + val values: Array[Literal[_]] = new Array[Literal[_]](1000) + for (i <- 0 until 1000) { + values(i) = LiteralValue(i, IntegerType) + } + val filter3 = new In(ref("a"), values) + var expected = "a IN (" + for (i <- 0 until 50) { + expected += i + ", " + } + expected = expected.dropRight(2) // remove the last ", " + expected += "...)" + assert(filter3.describe.equals(expected)) + } + + test("IsNull") { + val filter1 = new IsNull(ref("a")) + val filter2 = new IsNull(ref("a")) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("a IS NULL")) + } + + test("IsNotNull") { + val filter1 = new IsNotNull(ref("a")) + val filter2 = new IsNotNull(ref("a")) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("a IS NOT NULL")) + } + + test("Not") { + val filter1 = new Not(new LessThan(ref("a"), LiteralValue(1, IntegerType))) + val filter2 = new Not(new LessThan(ref("a"), LiteralValue(1, IntegerType))) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("NOT (a < 1)")) + } + + test("And") { + val filter1 = new And(new EqualTo(ref("a"), LiteralValue(1, IntegerType)), + new EqualTo(ref("b"), LiteralValue(1, IntegerType))) + val filter2 = new And(new EqualTo(ref("a"), LiteralValue(1, IntegerType)), + new EqualTo(ref("b"), LiteralValue(1, IntegerType))) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a", "b")) + 
assert(filter1.describe.equals("(a = 1) AND (b = 1)")) + } + + test("Or") { + val filter1 = new Or(new EqualTo(ref("a"), LiteralValue(1, IntegerType)), + new EqualTo(ref("b"), LiteralValue(1, IntegerType))) + val filter2 = new Or(new EqualTo(ref("a"), LiteralValue(1, IntegerType)), + new EqualTo(ref("b"), LiteralValue(1, IntegerType))) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a", "b")) + assert(filter1.describe.equals("(a = 1) OR (b = 1)")) + } + + test("StringStartsWith") { + val filter1 = new StringStartsWith(ref("a"), UTF8String.fromString("str")) + val filter2 = new StringStartsWith(ref("a"), UTF8String.fromString("str")) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("STRING_STARTS_WITH(a, str)")) + } + + test("StringEndsWith") { + val filter1 = new StringEndsWith(ref("a"), UTF8String.fromString("str")) + val filter2 = new StringEndsWith(ref("a"), UTF8String.fromString("str")) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("STRING_ENDS_WITH(a, str)")) + } + + test("StringContains") { + val filter1 = new StringContains(ref("a"), UTF8String.fromString("str")) + val filter2 = new StringContains(ref("a"), UTF8String.fromString("str")) + assert(filter1.equals(filter2)) + assert(filter1.references.map(_.describe()).toSeq == Seq("a")) + assert(filter1.describe.equals("STRING_CONTAINS(a, str)")) + } +} + +object FiltersV2Suite { + private[sql] def ref(parts: String*): FieldReference = { + new FieldReference(parts) + } +} From 871b2bd7890ecb0b7203b3310c607cbb8d62b5b5 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 22 Sep 2021 16:58:13 +0800 Subject: [PATCH 02/53] [SPARK-36760][SQL] Add interface SupportsPushDownV2Filters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: DB Tsai d_tsaiapple.com Co-Authored-By: Huaxin Gao huaxin_gaoapple.com ### What changes were proposed in this pull request? This is the 2nd PR for V2 Filter support. This PR does the following: - Add interface SupportsPushDownV2Filters Future work: - refactor `OrcFilters`, `ParquetFilters`, `JacksonParser`, `UnivocityParser` so both V1 file source and V2 file source can use them - For V2 file source: implement v2 filter -> parquet/orc filter. csv and Json don't have real filters, but also need to change the current code to have v2 filter -> `JacksonParser`/`UnivocityParser` - For V1 file source, keep what we currently have: v1 filter -> parquet/orc filter - We don't need v1filter.toV2 and v2filter.toV1 since we have two separate paths The reasons that we have reached the above conclusion: - The major motivation to implement V2Filter is to eliminate the unnecessary conversion between Catalyst types and Scala types when using Filters. - We provide this `SupportsPushDownV2Filters` in this PR so V2 data source (e.g. iceberg) can implement it and use V2 Filters - There are lots of work to implement v2 filters in the V2 file sources because of the following reasons: possible approaches for implementing V2Filter: 1. keep what we have for file source v1: v1 filter -> parquet/orc filter file source v2 we will implement v2 filter -> parquet/orc filter We don't need v1->v2 and v2->v1 problem with this approach: there are lots of code duplication 2. 
We will implement v2 filter -> parquet/orc filter for file source v2, and for file source v1: v1 filter -> v2 filter -> parquet/orc filter. We will need V1 -> V2.
This is the approach I am using in https://github.com/apache/spark/pull/33973. In that PR, I have:
  v2 orc: v2 filter -> orc filter
  v1 orc: v1 -> v2 -> orc filter
  v2 csv: v2 -> v1, new UnivocityParser
  v1 csv: new UnivocityParser
  v2 Json: v2 -> v1, new JacksonParser
  v1 Json: new JacksonParser
csv and Json don't have real filters, they just use filter references, so it should be OK to use either v1 or v2; it is easier to use v1 because nothing needs to change. I haven't finished parquet yet. The PR doesn't have the parquet V2Filter implementation, but I plan to have:
  v2 parquet: v2 filter -> parquet filter
  v1 parquet: v1 -> v2 -> parquet filter
Problems with this approach:
  1. It's not easy to implement V1 -> V2 because a V2 filter has `LiteralValue` and needs type info. We already lost the type information when we converted the catalyst Expression filter to a v1 filter.
  2. parquet is OK. Using Timestamp as an example, the parquet filter takes a long for timestamps:
     v2 parquet: v2 filter -> parquet filter, i.e. timestamp Expression (Long) -> v2 filter (LiteralValue Long) -> parquet filter (Long)
     v1 parquet: v1 -> v2 -> parquet filter, i.e. timestamp Expression (Long) -> v1 filter (Timestamp) -> v2 filter (LiteralValue Long) -> parquet filter (Long)
     But we have a problem for orc because the orc filter takes a java Timestamp:
     v2 orc: v2 filter -> orc filter, i.e. timestamp Expression (Long) -> v2 filter (LiteralValue Long) -> orc filter (Timestamp)
     v1 orc: v1 -> v2 -> orc filter, i.e. timestamp Expression (Long) -> v1 filter (Timestamp) -> v2 filter (LiteralValue Long) -> orc filter (Timestamp)
     This defeats the purpose of implementing v2 filters.

3. keep what we have for file source v1: v1 filter -> parquet/orc filter; for file source v2: v2 filter -> v1 filter -> parquet/orc filter. We will need V2 -> V1, and we have a similar problem as with approach 2.

So the conclusion is: approach 1 (keep v1 filter -> parquet/orc filter for file source v1, and implement v2 filter -> parquet/orc filter for file source v2) is better, but there is a lot of code duplication. We will need to refactor `OrcFilters`, `ParquetFilters`, `JacksonParser`, `UnivocityParser` so both the V1 file source and the V2 file source can use them.

### Why are the changes needed?
Use V2Filters to eliminate the unnecessary conversion between Catalyst types and Scala types.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Added new UT

Closes #34001 from huaxingao/v2filter.
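To make the type-information point concrete, here is a minimal sketch (not part of this patch; the column name and literal values are placeholders) comparing a V1 filter with the new V2 filter for a timestamp predicate:

```
import java.sql.Timestamp

import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{EqualTo => V2EqualTo}
import org.apache.spark.sql.sources.{EqualTo => V1EqualTo}
import org.apache.spark.sql.types.TimestampType

// V1 filters carry external Scala/Java values, so the internal catalyst Long
// (microseconds) has to be converted to java.sql.Timestamp and later back again.
val v1Filter = V1EqualTo("ts", Timestamp.valueOf("2021-09-22 00:00:00"))

// V2 filters keep the catalyst value together with its DataType, so no conversion
// is needed (the Long below is the internal microseconds-since-epoch representation).
val v2Filter = new V2EqualTo(FieldReference("ts"), LiteralValue(1632268800000000L, TimestampType))
```

Because the V2 `LiteralValue` records both the catalyst value and its `DataType`, translating back from a V1 filter would require re-deriving that type, which is exactly the difficulty described above.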
Lead-authored-by: Huaxin Gao Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../connector/expressions/filter/Filter.java | 4 +- .../read/SupportsPushDownV2Filters.java | 57 ++++++ .../datasources/v2/DataSourceV2Strategy.scala | 165 +++++++++++++++- .../datasources/v2/PushDownUtils.scala | 42 +++- .../v2/V2ScanRelationPushDown.scala | 8 +- .../JavaAdvancedDataSourceV2WithV2Filter.java | 186 ++++++++++++++++++ .../sql/connector/DataSourceV2Suite.scala | 142 +++++++++++++ 7 files changed, 593 insertions(+), 11 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java index 852837496a103..aa1fa082dc92c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.expressions.filter; +import java.io.Serializable; + import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -27,7 +29,7 @@ * @since 3.3.0 */ @Evolving -public abstract class Filter implements Expression { +public abstract class Filter implements Expression, Serializable { protected static final NamedReference[] EMPTY_REFERENCE = new NamedReference[0]; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java new file mode 100644 index 0000000000000..1ba9939dd0849 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.filter.Filter; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down filters to the data source and reduce the size of the data to be read. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownV2Filters extends ScanBuilder { + + /** + * Pushes down filters, and returns filters that need to be evaluated after scanning. + *

+ * Rows should be returned from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. + */ + Filter[] pushFilters(Filter[] filters); + + /** + * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}. + *

+ * There are 3 kinds of filters:
+ *   1. pushable filters which don't need to be evaluated again after scanning.
+ *   2. pushable filters which still need to be evaluated after scanning, e.g. parquet row
+ *      group filter.
+ *   3. non-pushable filters.
+ *
+ * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + *

+ * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} + * is never called, empty array should be returned for this case. + */ + Filter[] pushedFilters(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1a50c320ea3e3..203ec510e4a9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,24 +18,30 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, EmptyRow, Expression, Literal, NamedExpression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog} +import org.apache.spark.sql.connector.expressions.{FieldReference, Literal => V2Literal, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, EqualNullSafe => V2EqualNullSafe, EqualTo => V2EqualTo, Filter => V2Filter, GreaterThan => V2GreaterThan, GreaterThanOrEqual => V2GreaterThanOrEqual, In => V2In, IsNotNull => V2IsNotNull, IsNull => V2IsNull, LessThan => V2LessThan, LessThanOrEqual => V2LessThanOrEqual, Not => V2Not, Or => V2Or, StringContains => V2StringContains, StringEndsWith => V2StringEndsWith, StringStartsWith => V2StringStartsWith} import org.apache.spark.sql.connector.read.LocalScan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn, PushableColumnBase} import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.{BaseRelation, TableScan} +import org.apache.spark.sql.types.{BooleanType, StringType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper { @@ -427,3 +433,158 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case _ => Nil } } + +private[sql] object 
DataSourceV2Strategy { + + private def translateLeafNodeFilterV2( + predicate: Expression, + pushableColumn: PushableColumnBase): Option[V2Filter] = predicate match { + case expressions.EqualTo(pushableColumn(name), Literal(v, t)) => + Some(new V2EqualTo(FieldReference(name), LiteralValue(v, t))) + case expressions.EqualTo(Literal(v, t), pushableColumn(name)) => + Some(new V2EqualTo(FieldReference(name), LiteralValue(v, t))) + + case expressions.EqualNullSafe(pushableColumn(name), Literal(v, t)) => + Some(new V2EqualNullSafe(FieldReference(name), LiteralValue(v, t))) + case expressions.EqualNullSafe(Literal(v, t), pushableColumn(name)) => + Some(new V2EqualNullSafe(FieldReference(name), LiteralValue(v, t))) + + case expressions.GreaterThan(pushableColumn(name), Literal(v, t)) => + Some(new V2GreaterThan(FieldReference(name), LiteralValue(v, t))) + case expressions.GreaterThan(Literal(v, t), pushableColumn(name)) => + Some(new V2LessThan(FieldReference(name), LiteralValue(v, t))) + + case expressions.LessThan(pushableColumn(name), Literal(v, t)) => + Some(new V2LessThan(FieldReference(name), LiteralValue(v, t))) + case expressions.LessThan(Literal(v, t), pushableColumn(name)) => + Some(new V2GreaterThan(FieldReference(name), LiteralValue(v, t))) + + case expressions.GreaterThanOrEqual(pushableColumn(name), Literal(v, t)) => + Some(new V2GreaterThanOrEqual(FieldReference(name), LiteralValue(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), pushableColumn(name)) => + Some(new V2LessThanOrEqual(FieldReference(name), LiteralValue(v, t))) + + case expressions.LessThanOrEqual(pushableColumn(name), Literal(v, t)) => + Some(new V2LessThanOrEqual(FieldReference(name), LiteralValue(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), pushableColumn(name)) => + Some(new V2GreaterThanOrEqual(FieldReference(name), LiteralValue(v, t))) + + case in @ expressions.InSet(pushableColumn(name), set) => + val values: Array[V2Literal[_]] = + set.toSeq.map(elem => LiteralValue(elem, in.dataType)).toArray + Some(new V2In(FieldReference(name), values)) + + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. + case in @ expressions.In(pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) => + val hSet = list.map(_.eval(EmptyRow)) + Some(new V2In(FieldReference(name), + hSet.toArray.map(LiteralValue(_, in.value.dataType)))) + + case expressions.IsNull(pushableColumn(name)) => + Some(new V2IsNull(FieldReference(name))) + case expressions.IsNotNull(pushableColumn(name)) => + Some(new V2IsNotNull(FieldReference(name))) + + case expressions.StartsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(new V2StringStartsWith(FieldReference(name), v)) + + case expressions.EndsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(new V2StringEndsWith(FieldReference(name), v)) + + case expressions.Contains(pushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(new V2StringContains(FieldReference(name), v)) + + case expressions.Literal(true, BooleanType) => + Some(new V2AlwaysTrue) + + case expressions.Literal(false, BooleanType) => + Some(new V2AlwaysFalse) + + case _ => None + } + + /** + * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. + * + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. 
+ */ + protected[sql] def translateFilterV2( + predicate: Expression, + supportNestedPredicatePushdown: Boolean): Option[V2Filter] = { + translateFilterV2WithMapping(predicate, None, supportNestedPredicatePushdown) + } + + /** + * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. + * + * @param predicate The input [[Expression]] to be translated as [[Filter]] + * @param translatedFilterToExpr An optional map from leaf node filter expressions to its + * translated [[Filter]]. The map is used for rebuilding + * [[Expression]] from [[Filter]]. + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. + */ + protected[sql] def translateFilterV2WithMapping( + predicate: Expression, + translatedFilterToExpr: Option[mutable.HashMap[V2Filter, Expression]], + nestedPredicatePushdownEnabled: Boolean) + : Option[V2Filter] = { + predicate match { + case expressions.And(left, right) => + // See SPARK-12218 for detailed discussion + // It is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have (a = 2 AND trim(b) = 'blah') OR (c > 0) + // and we do not understand how to convert trim(b) = 'blah'. + // If we only convert a = 2, we will end up with + // (a = 2) OR (c > 0), which will generate wrong results. + // Pushing one leg of AND down is only safe to do at the top level. + // You can see ParquetFilters' createFilter for more details. + for { + leftFilter <- translateFilterV2WithMapping( + left, translatedFilterToExpr, nestedPredicatePushdownEnabled) + rightFilter <- translateFilterV2WithMapping( + right, translatedFilterToExpr, nestedPredicatePushdownEnabled) + } yield new V2And(leftFilter, rightFilter) + + case expressions.Or(left, right) => + for { + leftFilter <- translateFilterV2WithMapping( + left, translatedFilterToExpr, nestedPredicatePushdownEnabled) + rightFilter <- translateFilterV2WithMapping( + right, translatedFilterToExpr, nestedPredicatePushdownEnabled) + } yield new V2Or(leftFilter, rightFilter) + + case expressions.Not(child) => + translateFilterV2WithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) + .map(new V2Not(_)) + + case other => + val filter = translateLeafNodeFilterV2( + other, PushableColumn(nestedPredicatePushdownEnabled)) + if (filter.isDefined && translatedFilterToExpr.isDefined) { + translatedFilterToExpr.get(filter.get) = predicate + } + filter + } + } + + protected[sql] def rebuildExpressionFromFilter( + filter: V2Filter, + translatedFilterToExpr: mutable.HashMap[V2Filter, Expression]): Expression = { + filter match { + case and: V2And => + expressions.And(rebuildExpressionFromFilter(and.left, translatedFilterToExpr), + rebuildExpressionFromFilter(and.right, translatedFilterToExpr)) + case or: V2Or => + expressions.Or(rebuildExpressionFromFilter(or.left, translatedFilterToExpr), + rebuildExpressionFromFilter(or.right, translatedFilterToExpr)) + case not: V2Not => + expressions.Not(rebuildExpressionFromFilter(not.child, translatedFilterToExpr)) + case other => + translatedFilterToExpr.getOrElse(other, + throw new IllegalStateException("Failed to rebuild Expression for filter: " + filter)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index acc645741819e..c40f6d909a565 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -24,10 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -40,7 +39,7 @@ object PushDownUtils extends PredicateHelper { */ def pushFilters( scanBuilder: ScanBuilder, - filters: Seq[Expression]): (Seq[sources.Filter], Seq[Expression]) = { + filters: Seq[Expression]): (Either[Seq[sources.Filter], Seq[V2Filter]], Seq[Expression]) = { scanBuilder match { case r: SupportsPushDownFilters => // A map from translated data source leaf node filters to original catalyst filter @@ -69,9 +68,38 @@ object PushDownUtils extends PredicateHelper { val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter => DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) } - (r.pushedFilters(), (untranslatableExprs ++ postScanFilters).toSeq) + (Left(r.pushedFilters()), (untranslatableExprs ++ postScanFilters).toSeq) - case _ => (Nil, filters) + case r: SupportsPushDownV2Filters => + // A map from translated data source leaf node filters to original catalyst filter + // expressions. For a `And`/`Or` predicate, it is possible that the predicate is partially + // pushed down. This map can be used to construct a catalyst filter expression from the + // input filter, or a superset(partial push down filter) of the input filter. + val translatedFilterToExpr = mutable.HashMap.empty[V2Filter, Expression] + val translatedFilters = mutable.ArrayBuffer.empty[V2Filter] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = + DataSourceV2Strategy.translateFilterV2WithMapping( + filterExpr, Some(translatedFilterToExpr), nestedPredicatePushdownEnabled = true) + if (translated.isEmpty) { + untranslatableExprs += filterExpr + } else { + translatedFilters += translated.get + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. 
+ val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter => + DataSourceV2Strategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) + } + (Right(r.pushedFilters), (untranslatableExprs ++ postScanFilters).toSeq) + + case _ => (Left(Nil), filters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 046155b55cc2d..ec45a5d7853c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -58,12 +58,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( sHolder.builder, normalizedFiltersWithoutSubquery) + val pushedFiltersStr = if (pushedFilters.isLeft) { + pushedFilters.left.get.mkString(", ") + } else { + pushedFilters.right.get.mkString(", ") + } + val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery logInfo( s""" |Pushing operators to ${sHolder.relation.name} - |Pushed Filters: ${pushedFilters.mkString(", ")} + |Pushed Filters: $pushedFiltersStr |Post-Scan Filters: ${postScanFilters.mkString(",")} """.stripMargin) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java new file mode 100644 index 0000000000000..b92206c6a5444 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.connector; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.TestingV2Source; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.expressions.filter.Filter; +import org.apache.spark.sql.connector.read.*; +import org.apache.spark.sql.connector.expressions.filter.GreaterThan; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class JavaAdvancedDataSourceV2WithV2Filter implements TestingV2Source { + + @Override + public Table getTable(CaseInsensitiveStringMap options) { + return new JavaSimpleBatchTable() { + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + return new AdvancedScanBuilderWithV2Filter(); + } + }; + } + + static class AdvancedScanBuilderWithV2Filter implements ScanBuilder, Scan, + SupportsPushDownV2Filters, SupportsPushDownRequiredColumns { + + private StructType requiredSchema = TestingV2Source.schema(); + private Filter[] filters = new Filter[0]; + + @Override + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; + } + + @Override + public StructType readSchema() { + return requiredSchema; + } + + @Override + public Filter[] pushFilters(Filter[] filters) { + Filter[] supported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return gt.column().describe().equals("i") && gt.value().value() instanceof Integer; + } else { + return false; + } + }).toArray(Filter[]::new); + + Filter[] unsupported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return !gt.column().describe().equals("i") || !(gt.value().value() instanceof Integer); + } else { + return true; + } + }).toArray(Filter[]::new); + + this.filters = supported; + return unsupported; + } + + @Override + public Filter[] pushedFilters() { + return filters; + } + + @Override + public Scan build() { + return this; + } + + @Override + public Batch toBatch() { + return new AdvancedBatchWithV2Filter(requiredSchema, filters); + } + } + + public static class AdvancedBatchWithV2Filter implements Batch { + // Exposed for testing. 
+ public StructType requiredSchema; + public Filter[] filters; + + AdvancedBatchWithV2Filter(StructType requiredSchema, Filter[] filters) { + this.requiredSchema = requiredSchema; + this.filters = filters; + } + + @Override + public InputPartition[] planInputPartitions() { + List res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.column().describe()) && f.value().value() instanceof Integer) { + lowerBound = (Integer) f.value().value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaRangeInputPartition(0, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 4) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 9) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); + } + + return res.stream().toArray(InputPartition[]::new); + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new AdvancedReaderFactoryWithV2Filter(requiredSchema); + } + } + + static class AdvancedReaderFactoryWithV2Filter implements PartitionReaderFactory { + StructType requiredSchema; + + AdvancedReaderFactoryWithV2Filter(StructType requiredSchema) { + this.requiredSchema = requiredSchema; + } + + @Override + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } + + @Override + public InternalRow get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = current; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -current; + } + } + return new GenericInternalRow(values); + } + + @Override + public void close() throws IOException { + + } + }; + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index b42d48d873fee..2db9f0583a2ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter, GreaterThan => V2GreaterThan} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -54,6 +55,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS }.head } + private def getBatchWithV2Filter(query: DataFrame): AdvancedBatchWithV2Filter = { + query.queryExecution.executedPlan.collect { + case d: BatchScanExec => + d.batch.asInstanceOf[AdvancedBatchWithV2Filter] + }.head + } + private def getJavaBatch(query: DataFrame): 
JavaAdvancedDataSourceV2.AdvancedBatch = { query.queryExecution.executedPlan.collect { case d: BatchScanExec => @@ -61,6 +69,14 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS }.head } + private def getJavaBatchWithV2Filter( + query: DataFrame): JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter = { + query.queryExecution.executedPlan.collect { + case d: BatchScanExec => + d.batch.asInstanceOf[JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter] + }.head + } + test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -131,6 +147,66 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } + test("advanced implementation with V2 Filter") { + Seq(classOf[AdvancedDataSourceV2WithV2Filter], classOf[JavaAdvancedDataSourceV2WithV2Filter]) + .foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + + val q1 = df.select('j) + checkAnswer(q1, (0 until 10).map(i => Row(-i))) + if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { + val batch = getBatchWithV2Filter(q1) + assert(batch.filters.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) + } else { + val batch = getJavaBatchWithV2Filter(q1) + assert(batch.filters.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) + } + + val q2 = df.filter('i > 3) + checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) + if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { + val batch = getBatchWithV2Filter(q2) + assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i", "j")) + } else { + val batch = getJavaBatchWithV2Filter(q2) + assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i", "j")) + } + + val q3 = df.select('i).filter('i > 6) + checkAnswer(q3, (7 until 10).map(i => Row(i))) + if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { + val batch = getBatchWithV2Filter(q3) + assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i")) + } else { + val batch = getJavaBatchWithV2Filter(q3) + assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i")) + } + + val q4 = df.select('j).filter('j < -10) + checkAnswer(q4, Nil) + if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { + val batch = getBatchWithV2Filter(q4) + // 'j < 10 is not supported by the testing data source. + assert(batch.filters.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) + } else { + val batch = getJavaBatchWithV2Filter(q4) + // 'j < 10 is not supported by the testing data source. 
+ assert(batch.filters.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) + } + } + } + } + test("columnar batch scan implementation") { Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -597,6 +673,72 @@ class AdvancedBatch(val filters: Array[Filter], val requiredSchema: StructType) } } +class AdvancedDataSourceV2WithV2Filter extends TestingV2Source { + + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new AdvancedScanBuilderWithV2Filter() + } + } +} + +class AdvancedScanBuilderWithV2Filter extends ScanBuilder + with Scan with SupportsPushDownV2Filters with SupportsPushDownRequiredColumns { + + var requiredSchema = TestingV2Source.schema + var filters = Array.empty[V2Filter] + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } + + override def readSchema(): StructType = requiredSchema + + override def pushFilters(filters: Array[V2Filter]): Array[V2Filter] = { + val (supported, unsupported) = filters.partition { + case _: V2GreaterThan => true + case _ => false + } + this.filters = supported + unsupported + } + + override def pushedFilters(): Array[V2Filter] = filters + + override def build(): Scan = this + + override def toBatch: Batch = new AdvancedBatchWithV2Filter(filters, requiredSchema) +} + +class AdvancedBatchWithV2Filter( + val filters: Array[V2Filter], + val requiredSchema: StructType) extends Batch { + + override def planInputPartitions(): Array[InputPartition] = { + val lowerBound = filters.collectFirst { + case gt: V2GreaterThan => gt.value + } + + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] + + if (lowerBound.isEmpty) { + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get.value.asInstanceOf[Integer] < 4) { + res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get.value.asInstanceOf[Integer] < 9) { + res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 10)) + } + + res.toArray + } + + override def createReaderFactory(): PartitionReaderFactory = { + new AdvancedReaderFactory(requiredSchema) + } +} + class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { From 8f73de6d9a38ea54a92465d03c58b2157cb3be16 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 28 Oct 2021 16:59:12 -0700 Subject: [PATCH 03/53] [SPARK-37020][SQL] DS V2 LIMIT push down ### What changes were proposed in this pull request? Push down limit to data source for better performance ### Why are the changes needed? For LIMIT, e.g. `SELECT * FROM table LIMIT 10`, Spark retrieves all the data from table and then returns 10 rows. If we can push LIMIT to data source side, the data transferred to Spark will be dramatically reduced. ### Does this PR introduce _any_ user-facing change? Yes. new interface `SupportsPushDownLimit` ### How was this patch tested? new test Closes #34291 from huaxingao/pushdownLimit. 
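As a usage sketch (not part of this patch; the catalog name, URL and table are placeholders), once `pushDownLimit` is enabled for a DSV2 JDBC catalog the LIMIT is compiled into the query sent to the database, as the JDBC changes below show:

```
import org.apache.spark.sql.SparkSession

// spark-shell / test style sketch; "myh2", the url and the table name are placeholders.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.catalog.myh2",
    "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
  .config("spark.sql.catalog.myh2.url", "jdbc:h2:mem:testdb")
  .config("spark.sql.catalog.myh2.pushDownLimit", "true")
  .getOrCreate()

// With push down enabled, LIMIT 10 becomes part of the query sent to the database
// and surfaces as PushedLimit in the physical plan.
val df = spark.read.table("myh2.test.employee").limit(10)
df.show()
```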
Authored-by: Huaxin Gao Signed-off-by: Huaxin Gao --- docs/sql-data-sources-jdbc.md | 9 +++ .../spark/sql/connector/read/ScanBuilder.java | 6 +- .../connector/read/SupportsPushDownLimit.java | 36 ++++++++++ .../sql/execution/DataSourceScanExec.scala | 4 +- .../datasources/DataSourceStrategy.scala | 3 + .../datasources/jdbc/JDBCOptions.scala | 4 ++ .../execution/datasources/jdbc/JDBCRDD.scala | 15 ++-- .../datasources/jdbc/JDBCRelation.scala | 6 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../datasources/v2/PushDownUtils.scala | 13 +++- .../v2/V2ScanRelationPushDown.scala | 30 ++++++-- .../datasources/v2/jdbc/JDBCScan.scala | 5 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 15 +++- .../apache/spark/sql/jdbc/DerbyDialect.scala | 4 ++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 12 ++++ .../spark/sql/jdbc/MsSqlServerDialect.scala | 3 + .../spark/sql/jdbc/TeradataDialect.scala | 3 + .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 69 ++++++++++++++++++- 18 files changed, 217 insertions(+), 24 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index 315f47696475c..e8023ceb05388 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -246,6 +246,15 @@ logging into the data sources. read + + pushDownLimit + false + + The option to enable or disable LIMIT push-down into the JDBC data source. The default value is false, in which case Spark does not push down LIMIT to the JDBC data source. Otherwise, if value sets to true, LIMIT is pushed down to the JDBC data source. SPARK still applies LIMIT on the result from data source even if LIMIT is pushed down. + + read + + keytab (none) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java index b46f620d4fedb..20c9d2e883923 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java @@ -21,9 +21,9 @@ /** * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ - * interfaces to do operator pushdown, and keep the operator pushdown result in the returned - * {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down - * aggregates or applies column pruning. + * interfaces to do operator push down, and keep the operator push down result in the returned + * {@link Scan}. When pushing down operators, the push down order is: + * filter -> aggregate -> limit -> column pruning. * * @since 3.0.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java new file mode 100644 index 0000000000000..7e50bf14d7817 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; + +/** + * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * push down LIMIT. Please note that the combination of LIMIT with other operations + * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownLimit extends ScanBuilder { + + /** + * Pushes down LIMIT to the data source. + */ + boolean pushLimit(int limit); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index efc459c8241fa..abdc6bdc0eb10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -104,6 +104,7 @@ case class RowDataSourceScanExec( filters: Set[Filter], handledFilters: Set[Filter], aggregation: Option[Aggregation], + limit: Option[Int], rdd: RDD[InternalRow], @transient relation: BaseRelation, tableIdentifier: Option[TableIdentifier]) @@ -153,7 +154,8 @@ case class RowDataSourceScanExec( "ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> seqToString(markedFilters.toSeq), "PushedAggregates" -> aggString, - "PushedGroupby" -> groupByString) + "PushedGroupby" -> groupByString) ++ + limit.map(value => "PushedLimit" -> s"LIMIT $value") } // Don't care about `rdd` and `tableIdentifier` when canonicalizing. 
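For connector authors, a minimal Scala sketch of a `ScanBuilder` honoring the `SupportsPushDownLimit` interface above (`MyScanBuilder` and `MyLimitedScan` are illustrative names, not part of this patch):

```
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownLimit}
import org.apache.spark.sql.types.StructType

class MyScanBuilder(schema: StructType) extends ScanBuilder with SupportsPushDownLimit {
  private var pushedLimit = 0 // 0 means no LIMIT was pushed down

  override def pushLimit(limit: Int): Boolean = {
    // Returning true tells Spark the LIMIT was pushed down; Spark still applies
    // LIMIT on the returned rows, so the source only needs a best-effort cut-off.
    pushedLimit = limit
    true
  }

  override def build(): Scan = new MyLimitedScan(schema, pushedLimit)
}

class MyLimitedScan(schema: StructType, limit: Int) extends Scan {
  override def readSchema(): StructType = schema
  // toBatch would apply `limit` when planning partitions or generating the source query.
}
```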
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index a53665fe2f0e4..d619d472d8253 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -336,6 +336,7 @@ object DataSourceStrategy Set.empty, Set.empty, None, + None, toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -410,6 +411,7 @@ object DataSourceStrategy pushedFilters.toSet, handledFilters, None, + None, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -433,6 +435,7 @@ object DataSourceStrategy pushedFilters.toSet, handledFilters, None, + None, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 8b2ae2beb6d4a..de0c2c6f19509 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -191,6 +191,9 @@ class JDBCOptions( // An option to allow/disallow pushing down aggregate into JDBC data source val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean + // An option to allow/disallow pushing down LIMIT into JDBC data source + val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean + // The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either // by --files option of spark-submit or manually val keytab = { @@ -263,6 +266,7 @@ object JDBCOptions { val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate") + val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit") val JDBC_KEYTAB = newOption("keytab") val JDBC_PRINCIPAL = newOption("principal") val JDBC_TABLE_COMMENT = newOption("tableComment") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index e024e4bb02102..7973850201826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -179,6 +179,8 @@ object JDBCRDD extends Logging { * @param options - JDBC options that contains url, table and other information. * @param outputSchema - The schema of the columns or aggregate columns to SELECT. * @param groupByColumns - The pushed down group by columns. + * @param limit - The pushed down limit. If the value is 0, it means no limit or limit + * is not pushed down. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". 
*/ @@ -190,7 +192,8 @@ object JDBCRDD extends Logging { parts: Array[Partition], options: JDBCOptions, outputSchema: Option[StructType] = None, - groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = { + groupByColumns: Option[Array[String]] = None, + limit: Int = 0): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = if (groupByColumns.isEmpty) { @@ -208,7 +211,8 @@ object JDBCRDD extends Logging { parts, url, options, - groupByColumns) + groupByColumns, + limit) } } @@ -226,7 +230,8 @@ private[jdbc] class JDBCRDD( partitions: Array[Partition], url: String, options: JDBCOptions, - groupByColumns: Option[Array[String]]) + groupByColumns: Option[Array[String]], + limit: Int) extends RDD[InternalRow](sc, Nil) { /** @@ -349,8 +354,10 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) + val myLimitClause: String = dialect.getLimitClause(limit) + val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" + - s" $getGroupByClause" + s" $getGroupByClause $myLimitClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 8098fa0b83a95..ff9fcd493f600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -298,7 +298,8 @@ private[sql] case class JDBCRelation( requiredColumns: Array[String], finalSchema: StructType, filters: Array[Filter], - groupByColumns: Option[Array[String]]): RDD[Row] = { + groupByColumns: Option[Array[String]], + limit: Int): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, @@ -308,7 +309,8 @@ private[sql] case class JDBCRelation( parts, jdbcOptions, Some(finalSchema), - groupByColumns).asInstanceOf[RDD[Row]] + groupByColumns, + limit).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 203ec510e4a9d..d18ba2b045698 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -93,7 +93,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, - DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) => + DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate, limit), output)) => val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) if (v1Relation.schema != scan.readSchema()) { throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError( @@ -101,12 +101,14 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } val rdd = v1Relation.buildScan() val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd) + val 
dsScan = RowDataSourceScanExec( output, output.toStructType, Set.empty, pushed.toSet, aggregate, + limit, unsafeRowRDD, v1Relation, tableIdentifier = None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index c40f6d909a565..cc503dd55c83d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources @@ -135,6 +135,17 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down LIMIT to the data source Scan + */ + def pushLimit(scanBuilder: ScanBuilder, limit: Int): Boolean = { + scanBuilder match { + case s: SupportsPushDownLimit => + s.pushLimit(limit) + case _ => false + } + } + /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index ec45a5d7853c9..960a1ea60598b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} @@ -36,7 +36,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { - applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan)))) + applyColumnPruning(applyLimit(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))) } private def createScanBuilder(plan: LogicalPlan) = plan.transform { @@ -225,6 +225,19 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { withProjection } + def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform { + case globalLimit @ Limit(IntegerLiteral(limitValue), child) => + child match { + case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 => + val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue) + if (limitPushed) { + sHolder.setLimit(Some(limitValue)) + } + globalLimit + case _ => globalLimit + } + } + private def getWrappedScan( scan: Scan, sHolder: ScanBuilderHolder, @@ -236,7 +249,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { f.pushedFilters() case _ => Array.empty[sources.Filter] } - V1ScanWrapper(v1, pushedFilters, aggregation) + V1ScanWrapper(v1, pushedFilters, aggregation, sHolder.pushedLimit) case _ => scan } } @@ -245,13 +258,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case class ScanBuilderHolder( output: Seq[AttributeReference], relation: DataSourceV2Relation, - builder: ScanBuilder) extends LeafNode + builder: ScanBuilder) extends LeafNode { + var pushedLimit: Option[Int] = None + private[sql] def setLimit(limit: Option[Int]): Unit = pushedLimit = limit +} + // A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by // the physical v1 scan node. 
case class V1ScanWrapper( v1Scan: V1Scan, handledFilters: Seq[sources.Filter], - pushedAggregate: Option[Aggregation]) extends Scan { + pushedAggregate: Option[Aggregation], + pushedLimit: Option[Int]) extends Scan { override def readSchema(): StructType = v1Scan.readSchema() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index ef42691e5ca94..94d9d1433f9d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -28,7 +28,8 @@ case class JDBCScan( prunedSchema: StructType, pushedFilters: Array[Filter], pushedAggregateColumn: Array[String] = Array(), - groupByColumns: Option[Array[String]]) extends V1Scan { + groupByColumns: Option[Array[String]], + pushedLimit: Int) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -43,7 +44,7 @@ case class JDBCScan( } else { pushedAggregateColumn } - relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns) + relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, pushedLimit) } }.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index b0de7c015c91a..14826748dd432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -21,7 +21,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} import org.apache.spark.sql.jdbc.JdbcDialects @@ -36,6 +36,7 @@ case class JDBCScanBuilder( with SupportsPushDownFilters with SupportsPushDownRequiredColumns with SupportsPushDownAggregates + with SupportsPushDownLimit with Logging { private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis @@ -44,6 +45,16 @@ case class JDBCScanBuilder( private var finalSchema = schema + private var pushedLimit = 0 + + override def pushLimit(limit: Int): Boolean = { + if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) { + pushedLimit = limit + return true + } + false + } + override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { val dialect = JdbcDialects.get(jdbcOptions.url) @@ -123,6 +134,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. 
JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter, - pushedAggregateList, pushedGroupByCols) + pushedAggregateList, pushedGroupByCols, pushedLimit) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 020733aaee8c0..ecb514abac01c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -57,4 +57,8 @@ private object DerbyDialect extends JdbcDialect { override def getTableCommentQuery(table: String, comment: String): String = { throw QueryExecutionErrors.commentOnTableUnsupportedError() } + + // ToDo: use fetch first n rows only for limit, e.g. + // select * from employee fetch first 10 rows only; + override def supportsLimit(): Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index aa957113b5ca5..5a0b9cb3a845e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -296,6 +296,18 @@ abstract class JdbcDialect extends Serializable with Logging{ def classifyException(message: String, e: Throwable): AnalysisException = { new AnalysisException(message, cause = Some(e)) } + + /** + * returns the LIMIT clause for the SELECT statement + */ + def getLimitClause(limit: Integer): String = { + if (limit > 0 ) s"LIMIT $limit" else "" + } + + /** + * returns whether the dialect supports limit or not + */ + def supportsLimit(): Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index ea9834830e373..8dad5ef8e1eae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -118,4 +118,7 @@ private object MsSqlServerDialect extends JdbcDialect { override def getTableCommentQuery(table: String, comment: String): String = { throw QueryExecutionErrors.commentOnTableUnsupportedError() } + + // ToDo: use top n to get limit, e.g. select top 100 * from employee; + override def supportsLimit(): Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 58fe62cb6e088..2a776bdb7ab04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -55,4 +55,7 @@ private case object TeradataDialect extends JdbcDialect { override def renameTable(oldTable: String, newTable: String): String = { s"RENAME TABLE $oldTable TO $newTable" } + + // ToDo: use top n to get limit, e.g. 
select top 100 * from employee; + override def supportsLimit(): Boolean = false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 526dad91e5e19..4e82ec34837cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -21,10 +21,10 @@ import java.sql.{Connection, DriverManager} import java.util.Properties import org.apache.spark.SparkConf -import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.Filter -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{lit, sum, udf} import org.apache.spark.sql.test.SharedSparkSession @@ -42,6 +42,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .set("spark.sql.catalog.h2.url", url) .set("spark.sql.catalog.h2.driver", "org.h2.Driver") .set("spark.sql.catalog.h2.pushDownAggregate", "true") + .set("spark.sql.catalog.h2.pushDownLimit", "true") private def withConnection[T](f: Connection => T): T = { val conn = DriverManager.getConnection(url, new Properties()) @@ -92,6 +93,70 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2))) } + test("simple scan with LIMIT") { + val df1 = spark.read.table("h2.test.employee") + .where($"dept" === 1).limit(1) + checkPushedLimit(df1, true, 1) + checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .filter($"dept" > 1) + .limit(1) + checkPushedLimit(df2, true, 1) + checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0))) + + val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1") + val scan = df3.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq("NAME"))) + checkPushedLimit(df3, true, 1) + checkAnswer(df3, Seq(Row("alex"))) + + val df4 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .limit(1) + checkPushedLimit(df4, false, 0) + checkAnswer(df4, Seq(Row(1, 19000.00))) + + val df5 = spark.read + .table("h2.test.employee") + .sort("SALARY") + .limit(1) + checkPushedLimit(df5, false, 0) + checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0))) + + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val df6 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) + .limit(1) + // LIMIT is pushed down only if all the filters are pushed down + checkPushedLimit(df6, false, 0) + checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy"))) + } + + private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = { + df.queryExecution.optimizedPlan.collect { + case 
DataSourceV2ScanRelation(_, scan, _) => scan match { + case v1: V1ScanWrapper => + if (pushed) { + assert(v1.pushedLimit.nonEmpty && v1.pushedLimit.get === limit) + } else { + assert(v1.pushedLimit.isEmpty) + } + } + } + } + test("scan with filter push-down") { val df = spark.table("h2.test.people").filter($"id" > 1) val filters = df.queryExecution.optimizedPlan.collect { From 5dbcdc01e56050f446d484049ee5e1602d0acdf2 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 4 Nov 2021 22:26:34 +0800 Subject: [PATCH 04/53] [SPARK-37038][SQL] DSV2 Sample Push Down ### What changes were proposed in this pull request? Push down Sample to data source for better performance. If Sample is pushed down, it will be removed from logical plan so it will not be applied at Spark any more. Current Plan without Sample push down: ``` == Parsed Logical Plan == 'Project [*] +- 'Sample 0.0, 0.8, false, 157 +- 'UnresolvedRelation [postgresql, new_table], [], false == Analyzed Logical Plan == col1: int, col2: int Project [col1#163, col2#164] +- Sample 0.0, 0.8, false, 157 +- SubqueryAlias postgresql.new_table +- RelationV2[col1#163, col2#164] new_table == Optimized Logical Plan == Sample 0.0, 0.8, false, 157 +- RelationV2[col1#163, col2#164] new_table == Physical Plan == *(1) Sample 0.0, 0.8, false, 157 +- *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$16dde4769 [col1#163,col2#164] PushedAggregates: [], PushedFilters: [], PushedGroupby: [], ReadSchema: struct ``` after Sample push down: ``` == Parsed Logical Plan == 'Project [*] +- 'Sample 0.0, 0.8, false, 187 +- 'UnresolvedRelation [postgresql, new_table], [], false == Analyzed Logical Plan == col1: int, col2: int Project [col1#163, col2#164] +- Sample 0.0, 0.8, false, 187 +- SubqueryAlias postgresql.new_table +- RelationV2[col1#163, col2#164] new_table == Optimized Logical Plan == RelationV2[col1#163, col2#164] new_table == Physical Plan == *(1) Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$165b57543 [col1#163,col2#164] PushedAggregates: [], PushedFilters: [], PushedGroupby: [], PushedSample: TABLESAMPLE 0.0 0.8 false 187, ReadSchema: struct ``` The new interface is implemented using JDBC for POC and end to end test. TABLESAMPLE is not supported by all the databases. It is implemented using postgresql in this PR. ### Why are the changes needed? Reduce IO and improve performance. For SAMPLE, e.g. `SELECT * FROM t TABLESAMPLE (1 PERCENT)`, Spark retrieves all the data from table and then return 1% rows. It will dramatically reduce the transferred data size and improve performance if we can push Sample to data source side. ### Does this PR introduce any user-facing change? Yes. new interface `SupportsPushDownTableSample` ### How was this patch tested? New test Closes #34451 from huaxingao/sample. 
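For illustration only (not part of this patch): a minimal Scala sketch of how a connector's `ScanBuilder` might adopt the new `SupportsPushDownTableSample` mix-in introduced below. The class and field names are hypothetical; only the `pushTableSample` signature comes from the interface added in this patch, and returning `false` simply leaves the `Sample` node in the Spark plan.

```
// Hypothetical connector-side sketch, assuming only the interface added in this patch.
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownTableSample}
import org.apache.spark.sql.types.StructType

class MySampleScanBuilder(schema: StructType) extends ScanBuilder with SupportsPushDownTableSample {
  // Records the accepted sample; None means the source declined and Spark keeps the Sample node.
  private var sample: Option[(Double, Double, Long)] = None

  override def pushTableSample(
      lowerBound: Double,
      upperBound: Double,
      withReplacement: Boolean,
      seed: Long): Boolean = {
    if (withReplacement) {
      false // this hypothetical source cannot sample with replacement, so decline the push-down
    } else {
      sample = Some((lowerBound, upperBound, seed))
      true
    }
  }

  override def build(): Scan = new Scan {
    override def readSchema(): StructType = schema
    // A real implementation would translate `sample` into the source's TABLESAMPLE syntax here.
  }
}
```

Per the updated `ScanBuilder` javadoc in this patch, the optimizer attempts this sample push-down before filters, aggregates, and limits.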
Authored-by: Huaxin Gao Signed-off-by: Wenchen Fan --- docs/sql-data-sources-jdbc.md | 11 +- .../jdbc/v2/PostgresIntegrationSuite.scala | 5 + .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 111 +++++++++++++++++- .../spark/sql/connector/read/ScanBuilder.java | 2 +- .../read/SupportsPushDownTableSample.java | 39 ++++++ .../sql/execution/DataSourceScanExec.scala | 16 +-- .../datasources/DataSourceStrategy.scala | 10 +- .../datasources/jdbc/JDBCOptions.scala | 8 +- .../execution/datasources/jdbc/JDBCRDD.scala | 15 ++- .../datasources/jdbc/JDBCRelation.scala | 3 + .../datasources/v2/DataSourceV2Strategy.scala | 7 +- .../datasources/v2/PushDownUtils.scala | 14 ++- .../datasources/v2/PushedDownOperators.scala | 28 +++++ .../datasources/v2/TableSampleInfo.scala | 24 ++++ .../v2/V2ScanRelationPushDown.scala | 41 +++++-- .../datasources/v2/jdbc/JDBCScan.scala | 5 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 37 ++++-- .../apache/spark/sql/jdbc/JdbcDialects.scala | 6 + .../spark/sql/jdbc/PostgresDialect.scala | 10 ++ .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 19 ++- 20 files changed, 364 insertions(+), 47 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index e8023ceb05388..99e1a963a7954 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -250,7 +250,16 @@ logging into the data sources. pushDownLimit false - The option to enable or disable LIMIT push-down into the JDBC data source. The default value is false, in which case Spark does not push down LIMIT to the JDBC data source. Otherwise, if value sets to true, LIMIT is pushed down to the JDBC data source. SPARK still applies LIMIT on the result from data source even if LIMIT is pushed down. + The option to enable or disable LIMIT push-down into V2 JDBC data source. The default value is false, in which case Spark does not push down LIMIT to the JDBC data source. Otherwise, if value sets to true, LIMIT is pushed down to the JDBC data source. SPARK still applies LIMIT on the result from data source even if LIMIT is pushed down. + + read + + + + pushDownTableSample + false + + The option to enable or disable TABLESAMPLE push-down into V2 JDBC data source. The default value is false, in which case Spark does not push down TABLESAMPLE to the JDBC data source. Otherwise, if value sets to true, TABLESAMPLE is pushed down to the JDBC data source. 
read diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 932ddb90f6cb0..3ccf051fea52b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -49,6 +49,9 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.postgresql.pushDownTableSample", "true") + .set("spark.sql.catalog.postgresql.pushDownLimit", "true") + override def dataPreparation(conn: Connection): Unit = {} override def testUpdateColumnType(tbl: String): Unit = { @@ -75,4 +78,6 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) assert(t.schema === expectedSchema) } + + override def supportsTableSample: Boolean = true } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 1afe26afe1a9f..406cc41521e1c 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.jdbc.v2 import org.apache.log4j.Level import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sample} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -27,6 +30,8 @@ import org.apache.spark.tags.DockerTest @DockerTest private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFunSuite { + import testImplicits._ + val catalogName: String // dialect specific update column type test def testUpdateColumnType(tbl: String): Unit @@ -180,5 +185,109 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu testCreateTableWithProperty(s"$catalogName.new_table") } } -} + def supportsTableSample: Boolean = false + + private def samplePushed(df: DataFrame): Boolean = { + val sample = df.queryExecution.optimizedPlan.collect { + case s: Sample => s + } + sample.isEmpty + } + + private def filterPushed(df: DataFrame): Boolean = { + val filter = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + filter.isEmpty + } + + private def limitPushed(df: DataFrame, limit: Int): Boolean = { + val filter = df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + return v1.pushedDownOperators.limit == Some(limit) + } + } + false + } + + private def columnPruned(df: DataFrame, col: String): Boolean = { + val scan = df.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => 
s + }.get + scan.schema.names.sameElements(Seq(col)) + } + + test("SPARK-37038: Test TABLESAMPLE") { + if (supportsTableSample) { + withTable(s"$catalogName.new_table") { + sql(s"CREATE TABLE $catalogName.new_table (col1 INT, col2 INT)") + spark.range(10).select($"id" * 2, $"id" * 2 + 1).write.insertInto(s"$catalogName.new_table") + + // sample push down + column pruning + val df1 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + + " REPEATABLE (12345)") + assert(samplePushed(df1)) + assert(columnPruned(df1, "col1")) + assert(df1.collect().length < 10) + + // sample push down only + val df2 = sql(s"SELECT * FROM $catalogName.new_table TABLESAMPLE (50 PERCENT)" + + " REPEATABLE (12345)") + assert(samplePushed(df2)) + assert(df2.collect().length < 10) + + // sample(BUCKET ... OUT OF) push down + limit push down + column pruning + val df3 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + + " LIMIT 2") + assert(samplePushed(df3)) + assert(limitPushed(df3, 2)) + assert(columnPruned(df3, "col1")) + assert(df3.collect().length == 2) + + // sample(... PERCENT) push down + limit push down + column pruning + val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" + + " TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 2") + assert(samplePushed(df4)) + assert(limitPushed(df4, 2)) + assert(columnPruned(df4, "col1")) + assert(df4.collect().length == 2) + + // sample push down + filter push down + limit push down + val df5 = sql(s"SELECT * FROM $catalogName.new_table" + + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") + assert(samplePushed(df5)) + assert(filterPushed(df5)) + assert(limitPushed(df5, 2)) + assert(df5.collect().length == 2) + + // sample + filter + limit + column pruning + // sample pushed down, filer/limit not pushed down, column pruned + // Todo: push down filter/limit + val df6 = sql(s"SELECT col1 FROM $catalogName.new_table" + + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") + assert(samplePushed(df6)) + assert(!filterPushed(df6)) + assert(!limitPushed(df6, 2)) + assert(columnPruned(df6, "col1")) + assert(df6.collect().length == 2) + + // sample + limit + // Push down order is sample -> filter -> limit + // only limit is pushed down because in this test sample is after limit + val df7 = spark.read.table(s"$catalogName.new_table").limit(2).sample(0.5) + assert(!samplePushed(df7)) + assert(limitPushed(df7, 2)) + + // sample + filter + // Push down order is sample -> filter -> limit + // only filter is pushed down because in this test sample is after filter + val df8 = spark.read.table(s"$catalogName.new_table").where($"col1" > 1).sample(0.5) + assert(!samplePushed(df8)) + assert(filterPushed(df8)) + assert(df8.collect().length < 10) + } + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java index 20c9d2e883923..27ee534d804ff 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java @@ -23,7 +23,7 @@ * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ * interfaces to do operator push down, and keep the operator push down result in the returned * {@link Scan}. When pushing down operators, the push down order is: - * filter -> aggregate -> limit -> column pruning. 
+ * sample -> filter -> aggregate -> limit -> column pruning. * * @since 3.0.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java new file mode 100644 index 0000000000000..3630feb4680ea --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; + +/** + * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * push down SAMPLE. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownTableSample extends ScanBuilder { + + /** + * Pushes down SAMPLE to the data source. + */ + boolean pushTableSample( + double lowerBound, + double upperBound, + boolean withReplacement, + long seed); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index abdc6bdc0eb10..3e6f0fd222ecb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} +import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, Filter} @@ -103,8 +103,7 @@ case class RowDataSourceScanExec( requiredSchema: StructType, filters: Set[Filter], handledFilters: Set[Filter], - aggregation: Option[Aggregation], - limit: Option[Int], + pushedDownOperators: PushedDownOperators, rdd: RDD[InternalRow], @transient relation: BaseRelation, tableIdentifier: Option[TableIdentifier]) @@ -135,9 +134,9 @@ case class RowDataSourceScanExec( def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") - val (aggString, groupByString) = if (aggregation.nonEmpty) { - (seqToString(aggregation.get.aggregateExpressions), - seqToString(aggregation.get.groupByColumns)) + val 
(aggString, groupByString) = if (pushedDownOperators.aggregation.nonEmpty) { + (seqToString(pushedDownOperators.aggregation.get.aggregateExpressions), + seqToString(pushedDownOperators.aggregation.get.groupByColumns)) } else { ("[]", "[]") } @@ -155,7 +154,10 @@ case class RowDataSourceScanExec( "PushedFilters" -> seqToString(markedFilters.toSeq), "PushedAggregates" -> aggString, "PushedGroupby" -> groupByString) ++ - limit.map(value => "PushedLimit" -> s"LIMIT $value") + pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++ + pushedDownOperators.sample.map(v => "PushedSample" -> + s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" + ) } // Don't care about `rdd` and `tableIdentifier` when canonicalizing. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d619d472d8253..e8fb9ca3e46c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Coun import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.sources._ @@ -335,8 +336,7 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, - None, - None, + PushedDownOperators(None, None, None), toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -410,8 +410,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - None, - None, + PushedDownOperators(None, None, None), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -434,8 +433,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - None, - None, + PushedDownOperators(None, None, None), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index de0c2c6f19509..8e047d7f7c7d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -191,9 +191,14 @@ class JDBCOptions( // An option to allow/disallow pushing down aggregate into JDBC data source val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean - // An option to allow/disallow pushing down LIMIT into JDBC data source + // An option to allow/disallow pushing down LIMIT into V2 JDBC data source + // This only applies to Data Source V2 JDBC val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean + // An option to allow/disallow pushing down TABLESAMPLE into JDBC data 
source + // This only applies to Data Source V2 JDBC + val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean + // The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either // by --files option of spark-submit or manually val keytab = { @@ -267,6 +272,7 @@ object JDBCOptions { val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate") val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit") + val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample") val JDBC_KEYTAB = newOption("keytab") val JDBC_PRINCIPAL = newOption("principal") val JDBC_TABLE_COMMENT = newOption("tableComment") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 7973850201826..1b8d33b94fbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -181,6 +182,7 @@ object JDBCRDD extends Logging { * @param groupByColumns - The pushed down group by columns. * @param limit - The pushed down limit. If the value is 0, it means no limit or limit * is not pushed down. + * @param sample - The pushed down tableSample. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". 
*/ @@ -193,6 +195,7 @@ object JDBCRDD extends Logging { options: JDBCOptions, outputSchema: Option[StructType] = None, groupByColumns: Option[Array[String]] = None, + sample: Option[TableSampleInfo] = None, limit: Int = 0): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) @@ -212,6 +215,7 @@ object JDBCRDD extends Logging { url, options, groupByColumns, + sample, limit) } } @@ -231,6 +235,7 @@ private[jdbc] class JDBCRDD( url: String, options: JDBCOptions, groupByColumns: Option[Array[String]], + sample: Option[TableSampleInfo], limit: Int) extends RDD[InternalRow](sc, Nil) { @@ -354,10 +359,16 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) + val myTableSampleClause: String = if (sample.nonEmpty) { + JdbcDialects.get(url).getTableSample(sample.get) + } else { + "" + } + val myLimitClause: String = dialect.getLimitClause(limit) - val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" + - s" $getGroupByClause $myLimitClause" + val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" + + s" $myWhereClause $getGroupByClause $myLimitClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index ff9fcd493f600..cd1eae89ee890 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ @@ -299,6 +300,7 @@ private[sql] case class JDBCRelation( finalSchema: StructType, filters: Array[Filter], groupByColumns: Option[Array[String]], + tableSample: Option[TableSampleInfo], limit: Int): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( @@ -310,6 +312,7 @@ private[sql] case class JDBCRelation( jdbcOptions, Some(finalSchema), groupByColumns, + tableSample, limit).asInstanceOf[RDD[Row]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d18ba2b045698..d07c29e080265 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -92,8 +92,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(project, filters, - DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate, limit), output)) => + case PhysicalOperation(project, filters, DataSourceV2ScanRelation( + _, V1ScanWrapper(scan, 
pushed, pushedDownOperators), output)) => val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) if (v1Relation.schema != scan.readSchema()) { throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError( @@ -107,8 +107,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat output.toStructType, Set.empty, pushed.toSet, - aggregate, - limit, + pushedDownOperators, unsafeRowRDD, v1Relation, tableIdentifier = None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index cc503dd55c83d..a98b8979d3e3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources @@ -135,6 +135,18 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down TableSample to the data source Scan + */ + def pushTableSample(scanBuilder: ScanBuilder, sample: TableSampleInfo): Boolean = { + scanBuilder match { + case s: SupportsPushDownTableSample => + s.pushTableSample( + sample.lowerBound, sample.upperBound, sample.withReplacement, sample.seed) + case _ => false + } + } + /** * Pushes down LIMIT to the data source Scan */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala new file mode 100644 index 0000000000000..c21354d646164 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation + +/** + * Pushed down operators + */ +case class PushedDownOperators( + aggregation: Option[Aggregation], + sample: Option[TableSampleInfo], + limit: Option[Int]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala new file mode 100644 index 0000000000000..cb4fb9eb0809a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +case class TableSampleInfo( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 960a1ea60598b..f73f831903364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} @@ -36,7 +36,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { - applyColumnPruning(applyLimit(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))) + applyColumnPruning( + applyLimit(pushDownAggregates(pushDownFilters(pushDownSample(createScanBuilder(plan)))))) } private def createScanBuilder(plan: LogicalPlan) = plan.transform { @@ -225,13 +226,33 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { withProjection } + def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform { + 
case sample: Sample => sample.child match { + case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 => + val tableSample = TableSampleInfo( + sample.lowerBound, + sample.upperBound, + sample.withReplacement, + sample.seed) + val pushed = PushDownUtils.pushTableSample(sHolder.builder, tableSample) + if (pushed) { + sHolder.pushedSample = Some(tableSample) + sample.child + } else { + sample + } + + case _ => sample + } + } + def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform { case globalLimit @ Limit(IntegerLiteral(limitValue), child) => child match { case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 => val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue) if (limitPushed) { - sHolder.setLimit(Some(limitValue)) + sHolder.pushedLimit = Some(limitValue) } globalLimit case _ => globalLimit @@ -249,7 +270,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { f.pushedFilters() case _ => Array.empty[sources.Filter] } - V1ScanWrapper(v1, pushedFilters, aggregation, sHolder.pushedLimit) + val pushedDownOperators = + PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit) + V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan } } @@ -260,16 +283,16 @@ case class ScanBuilderHolder( relation: DataSourceV2Relation, builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None - private[sql] def setLimit(limit: Option[Int]): Unit = pushedLimit = limit + + var pushedSample: Option[TableSampleInfo] = None } -// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by -// the physical v1 scan node. +// A wrapper for v1 scan to carry the translated filters and the handled ones, along with +// other pushed down operators. This is required by the physical v1 scan node. 
case class V1ScanWrapper( v1Scan: V1Scan, handledFilters: Seq[sources.Filter], - pushedAggregate: Option[Aggregation], - pushedLimit: Option[Int]) extends Scan { + pushedDownOperators: PushedDownOperators) extends Scan { override def readSchema(): StructType = v1Scan.readSchema() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index 94d9d1433f9d4..ff79d1a5c4144 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -20,6 +20,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan} import org.apache.spark.sql.types.StructType @@ -29,6 +30,7 @@ case class JDBCScan( pushedFilters: Array[Filter], pushedAggregateColumn: Array[String] = Array(), groupByColumns: Option[Array[String]], + tableSample: Option[TableSampleInfo], pushedLimit: Int) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -44,7 +46,8 @@ case class JDBCScan( } else { pushedAggregateColumn } - relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, pushedLimit) + relation.buildScan( + columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, pushedLimit) } }.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 14826748dd432..7605b03f49ea5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -21,9 +21,10 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -37,6 +38,7 @@ case class JDBCScanBuilder( with SupportsPushDownRequiredColumns with SupportsPushDownAggregates with SupportsPushDownLimit + with SupportsPushDownTableSample with Logging { private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis @@ -45,15 +47,9 @@ case class JDBCScanBuilder( private var finalSchema = schema - private var pushedLimit = 0 + private var tableSample: Option[TableSampleInfo] = None - override def 
pushLimit(limit: Int): Boolean = { - if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) { - pushedLimit = limit - return true - } - false - } + private var pushedLimit = 0 override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { @@ -109,6 +105,27 @@ case class JDBCScanBuilder( } } + override def pushTableSample( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long): Boolean = { + if (jdbcOptions.pushDownTableSample && + JdbcDialects.get(jdbcOptions.url).supportsTableSample) { + this.tableSample = Some(TableSampleInfo(lowerBound, upperBound, withReplacement, seed)) + return true + } + false + } + + override def pushLimit(limit: Int): Boolean = { + if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) { + pushedLimit = limit + return true + } + false + } + override def pruneColumns(requiredSchema: StructType): Unit = { // JDBC doesn't support nested column pruning. // TODO (SPARK-32593): JDBC support nested column and nested column pruning. @@ -134,6 +151,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter, - pushedAggregateList, pushedGroupByCols, pushedLimit) + pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 5a0b9cb3a845e..75b0de987f1ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -308,6 +309,11 @@ abstract class JdbcDialect extends Serializable with Logging{ * returns whether the dialect supports limit or not */ def supportsLimit(): Boolean = true + + def supportsTableSample: Boolean = false + + def getTableSample(sample: TableSampleInfo): String = + throw new UnsupportedOperationException("TableSample is not supported by this data source") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 3ce785ed844c5..317ae19ed914b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -21,6 +21,7 @@ import java.sql.{Connection, Types} import java.util.Locale import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ @@ -154,4 +155,13 @@ private object PostgresDialect extends JdbcDialect { val nullable = if (isNullable) "DROP NOT NULL" else "SET NOT NULL" s"ALTER TABLE $tableName ALTER COLUMN ${quoteIdentifier(columnName)} $nullable" } + + override def supportsTableSample: Boolean = true + + override def getTableSample(sample: TableSampleInfo): 
String = { + // hard-coded to BERNOULLI for now because Spark doesn't have a way to specify sample + // method name + s"TABLESAMPLE BERNOULLI" + + s" (${(sample.upperBound - sample.lowerBound) * 100}) REPEATABLE (${sample.seed})" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 4e82ec34837cc..3c95fa2e66de7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -93,6 +93,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2))) } + // TABLESAMPLE ({integer_expression | decimal_expression} PERCENT) and + // TABLESAMPLE (BUCKET integer_expression OUT OF integer_expression) + // are tested in JDBC dialect tests because TABLESAMPLE is not supported by all the DBMS + test("TABLESAMPLE (integer_expression ROWS) is the same as LIMIT") { + val df = sql("SELECT NAME FROM h2.test.employee TABLESAMPLE (3 ROWS)") + val scan = df.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq("NAME"))) + checkPushedLimit(df, true, 3) + checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy"))) + } + test("simple scan with LIMIT") { val df1 = spark.read.table("h2.test.employee") .where($"dept" === 1).limit(1) @@ -146,12 +159,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = { df.queryExecution.optimizedPlan.collect { - case DataSourceV2ScanRelation(_, scan, _) => scan match { + case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => if (pushed) { - assert(v1.pushedLimit.nonEmpty && v1.pushedLimit.get === limit) + assert(v1.pushedDownOperators.limit === Some(limit)) } else { - assert(v1.pushedLimit.isEmpty) + assert(v1.pushedDownOperators.limit.isEmpty) } } } From 39b29d7920a94d2c907d4d599d6a3ad0a959c69c Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 3 Dec 2021 21:12:25 +0800 Subject: [PATCH 05/53] [SPARK-37286][SQL] Move compileAggregates from JDBCRDD to JdbcDialect ### What changes were proposed in this pull request? Currently, the method `compileAggregates` is a member of `JDBCRDD`. But it is not reasonable, because the JDBC source knowns how to compile aggregate expressions to itself's dialect well. ### Why are the changes needed? JDBC source knowns how to compile aggregate expressions to itself's dialect well. After this PR, we can extend the pushdown(e.g. aggregate) based on different dialect between different JDBC database. There are two situations: First, database A and B implement a different number of aggregate functions that meet the SQL standard. ### Does this PR introduce _any_ user-facing change? 'No'. Just change the inner implementation. ### How was this patch tested? Jenkins tests. Closes #34554 from beliefer/SPARK-37286. 
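A purely illustrative sketch of the per-dialect customization this refactor enables (the dialect name and URL prefix are invented; only `JdbcDialect.compileAggregate` and the `Sum` function come from this patch): a dialect whose database cannot evaluate `SUM(DISTINCT ...)` declines that one case and falls back to the default translation for everything else.

```
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Sum}
import org.apache.spark.sql.jdbc.JdbcDialect

// Hypothetical dialect: decline SUM(DISTINCT ...) so Spark keeps that aggregate,
// and reuse the default compileAggregate added in this patch for everything else.
private case object MyNoDistinctSumDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = aggFunction match {
    case s: Sum if s.isDistinct => None
    case other => super.compileAggregate(other)
  }
}
```

Because `compileAggregate` returns an `Option`, the `JDBCScanBuilder` check below (`compiledAggs.length != aggregation.aggregateExpressions.length`) automatically falls back to no aggregate push-down when any single function is declined.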
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../execution/datasources/jdbc/JDBCRDD.scala | 29 ----------------- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 6 ++-- .../apache/spark/sql/jdbc/JdbcDialects.scala | 31 +++++++++++++++++++ 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 1b8d33b94fbd2..394ba3f8bb8c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,7 +25,6 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ @@ -140,34 +139,6 @@ object JDBCRDD extends Logging { }) } - def compileAggregates( - aggregates: Seq[AggregateFunc], - dialect: JdbcDialect): Option[Seq[String]] = { - def quote(colName: String): String = dialect.quoteIdentifier(colName) - - Some(aggregates.map { - case min: Min => - if (min.column.fieldNames.length != 1) return None - s"MIN(${quote(min.column.fieldNames.head)})" - case max: Max => - if (max.column.fieldNames.length != 1) return None - s"MAX(${quote(max.column.fieldNames.head)})" - case count: Count => - if (count.column.fieldNames.length != 1) return None - val distinct = if (count.isDistinct) "DISTINCT " else "" - val column = quote(count.column.fieldNames.head) - s"COUNT($distinct$column)" - case sum: Sum => - if (sum.column.fieldNames.length != 1) return None - val distinct = if (sum.isDistinct) "DISTINCT " else "" - val column = quote(sum.column.fieldNames.head) - s"SUM($distinct$column)" - case _: CountStar => - s"COUNT(*)" - case _ => return None - }) - } - /** * Build and return JDBCRDD from the given information. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 7605b03f49ea5..d3c141ed53c5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -72,8 +72,8 @@ case class JDBCScanBuilder( if (!jdbcOptions.pushDownAggregate) return false val dialect = JdbcDialects.get(jdbcOptions.url) - val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect) - if (compiledAgg.isEmpty) return false + val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate(_)) + if (compiledAggs.length != aggregation.aggregateExpressions.length) return false val groupByCols = aggregation.groupByColumns.map { col => if (col.fieldNames.length != 1) return false @@ -84,7 +84,7 @@ case class JDBCScanBuilder( // e.g. 
"DEPT","NAME",MAX("SALARY"),MIN("BONUS") => // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee" // GROUP BY "DEPT", "NAME" - val selectList = groupByCols ++ compiledAgg.get + val selectList = groupByCols ++ compiledAggs val groupByClause = if (groupByCols.isEmpty) { "" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 75b0de987f1ca..f96323ec16a3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -190,6 +191,36 @@ abstract class JdbcDialect extends Serializable with Logging{ case _ => value } + /** + * Converts aggregate function to String representing a SQL expression. + * @param aggregate The aggregate function to be converted. + * @return Converted value. + */ + @Since("3.3.0") + def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + aggFunction match { + case min: Min => + if (min.column.fieldNames.length != 1) return None + Some(s"MIN(${quoteIdentifier(min.column.fieldNames.head)})") + case max: Max => + if (max.column.fieldNames.length != 1) return None + Some(s"MAX(${quoteIdentifier(max.column.fieldNames.head)})") + case count: Count => + if (count.column.fieldNames.length != 1) return None + val distinct = if (count.isDistinct) "DISTINCT " else "" + val column = quoteIdentifier(count.column.fieldNames.head) + Some(s"COUNT($distinct$column)") + case sum: Sum => + if (sum.column.fieldNames.length != 1) return None + val distinct = if (sum.isDistinct) "DISTINCT " else "" + val column = quoteIdentifier(sum.column.fieldNames.head) + Some(s"SUM($distinct$column)") + case _: CountStar => + Some(s"COUNT(*)") + case _ => None + } + } + /** * Return Some[true] iff `TRUNCATE TABLE` causes cascading default. * Some[true] : TRUNCATE TABLE causes cascading. From 560cabd5cbda59497db61ae4292e27f5e6b60b8a Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 3 Dec 2021 11:16:39 -0600 Subject: [PATCH 06/53] [SPARK-37286][DOCS][FOLLOWUP] Fix the wrong parameter name for Javadoc ### What changes were proposed in this pull request? This PR fixes an issue that the Javadoc generation fails due to the wrong parameter name of a method added in SPARK-37286 (#34554). https://github.com/apache/spark/runs/4409267346?check_suite_focus=true#step:9:5081 ### Why are the changes needed? To keep the build clean. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? GA itself. Closes #34801 from sarutak/followup-SPARK-37286. 
Authored-by: Kousuke Saruta Signed-off-by: Sean Owen --- .../src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index f96323ec16a3e..6d90432859d71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -193,7 +193,7 @@ abstract class JdbcDialect extends Serializable with Logging{ /** * Converts aggregate function to String representing a SQL expression. - * @param aggregate The aggregate function to be converted. + * @param aggFunction The aggregate function to be converted. * @return Converted value. */ @Since("3.3.0") From a792696709ea27db36d59c1b69400efaa6f5a527 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 11 Nov 2021 17:19:21 +0800 Subject: [PATCH 07/53] [SPARK-37262][SQL] Don't log empty aggregate and group by in JDBCScan ### What changes were proposed in this pull request? Currently, the empty pushed aggregate and pushed group by are logged in Explain for JDBCScan ``` Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$172e75786 [NAME#1,SALARY#2] PushedAggregates: [], PushedFilters: [IsNotNull(SALARY), GreaterThan(SALARY,100.00)], PushedGroupby: [], ReadSchema: struct ``` After the fix, the JDBCSScan will be ``` Scan org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCScan$$anon$172e75786 [NAME#1,SALARY#2] PushedFilters: [IsNotNull(SALARY), GreaterThan(SALARY,100.00)], ReadSchema: struct ``` ### Why are the changes needed? address this comment https://github.com/apache/spark/pull/34451#discussion_r740220800 ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests Closes #34540 from huaxingao/aggExplain. 
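A small standalone sketch of the `Option.fold` pattern the change below relies on, so the explain metadata only gains a `PushedAggregates`/`PushedGroupByColumns` entry when something was actually pushed down (the helper name is made up):

```
// Illustrative only: emit an explain entry only when the operator was pushed down.
def aggregateMetadata(pushed: Option[Seq[String]]): Map[String, String] =
  pushed.fold(Map.empty[String, String]) { exprs =>
    Map("PushedAggregates" -> exprs.mkString("[", ", ", "]"))
  }

aggregateMetadata(None)                      // Map() -> nothing shows up in EXPLAIN
aggregateMetadata(Some(Seq("MAX(SALARY)")))  // Map(PushedAggregates -> [MAX(SALARY)])
```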
Authored-by: Huaxin Gao Signed-off-by: Wenchen Fan --- .../sql/execution/DataSourceScanExec.scala | 14 +++------- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 26 +++++++++---------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 3e6f0fd222ecb..18ad5b81560e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -134,13 +134,6 @@ case class RowDataSourceScanExec( def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") - val (aggString, groupByString) = if (pushedDownOperators.aggregation.nonEmpty) { - (seqToString(pushedDownOperators.aggregation.get.aggregateExpressions), - seqToString(pushedDownOperators.aggregation.get.groupByColumns)) - } else { - ("[]", "[]") - } - val markedFilters = if (filters.nonEmpty) { for (filter <- filters) yield { if (handledFilters.contains(filter)) s"*$filter" else s"$filter" @@ -151,9 +144,10 @@ case class RowDataSourceScanExec( Map( "ReadSchema" -> requiredSchema.catalogString, - "PushedFilters" -> seqToString(markedFilters.toSeq), - "PushedAggregates" -> aggString, - "PushedGroupby" -> groupByString) ++ + "PushedFilters" -> seqToString(markedFilters.toSeq)) ++ + pushedDownOperators.aggregation.fold(Map[String, String]()) { v => + Map("PushedAggregates" -> seqToString(v.aggregateExpressions), + "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 3c95fa2e66de7..39b3c19ac1db8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -328,7 +328,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupby: [DEPT]" + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200))) @@ -345,7 +345,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [MAX(ID), MIN(ID)], " + "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " + - "PushedGroupby: []" + "PushedGroupByColumns: []" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(2, 1))) @@ -424,7 +424,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [SUM(SALARY)], " + "PushedFilters: [], " + - "PushedGroupby: [DEPT]" + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) @@ -437,7 +437,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [SUM(DISTINCT SALARY)], " + 
"PushedFilters: [], " + - "PushedGroupby: [DEPT]" + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) @@ -455,7 +455,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupby: [DEPT, NAME]" + "PushedGroupByColumns: [DEPT, NAME]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), @@ -474,7 +474,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupby: [DEPT]" + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200))) @@ -489,7 +489,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [MIN(SALARY)], " + "PushedFilters: [], " + - "PushedGroupby: [DEPT]" + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000))) @@ -512,7 +512,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [SUM(SALARY)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupby: [DEPT]" + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(query, expected_plan_fragment) } checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000))) @@ -525,7 +525,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY), SUM(BONUS)" + "PushedAggregates: [SUM(SALARY), SUM(BONUS)]" checkKeywordsExistsInExplain(query, expected_plan_fragment) } checkAnswer(query, Seq(Row(47100.0))) @@ -536,10 +536,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df1 = sql("select * from h2.test.employee").toDF(cols: _*) val df2 = df1.groupBy().sum("c") df2.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: []" // aggregate over alias not push down - checkKeywordsExistsInExplain(df2, expected_plan_fragment) + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.aggregation.isEmpty) + } } checkAnswer(df2, Seq(Row(53000.00))) } From 9eae482d0e2bd6f52120c02a612b654ee845fb37 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 16 Dec 2021 22:28:31 +0800 Subject: [PATCH 08/53] [SPARK-37483][SQL] Support push down top N to JDBC data source V2 ### What changes were proposed in this pull request? Currently, Spark supports push down limit to data source. However, in the user's scenario, limit must have the premise of order by. Because limit and order by are more valuable together. On the other hand, push down top N(same as order by ... limit N) outputs the data with basic order to Spark sort, the the sort of Spark may have some performance improvement. ### Why are the changes needed? 1. 
push down top N is very useful for users scenario. 2. push down top N could improves the performance of sort. ### Does this PR introduce _any_ user-facing change? 'No'. Just change the physical execute. ### How was this patch tested? New tests. Closes #34918 from beliefer/SPARK-37483. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../connector/read/SupportsPushDownLimit.java | 2 +- .../connector/read/SupportsPushDownTopN.java | 38 ++++++ .../sql/execution/DataSourceScanExec.scala | 12 +- .../datasources/DataSourceStrategy.scala | 26 +++- .../execution/datasources/jdbc/JDBCRDD.scala | 25 +++- .../datasources/jdbc/JDBCRelation.scala | 7 +- .../datasources/v2/PushDownUtils.scala | 15 ++- .../datasources/v2/PushedDownOperators.scala | 6 +- .../v2/V2ScanRelationPushDown.scala | 49 ++++++-- .../datasources/v2/jdbc/JDBCScan.scala | 8 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 19 ++- .../apache/spark/sql/jdbc/DerbyDialect.scala | 6 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 5 - .../spark/sql/jdbc/MsSqlServerDialect.scala | 5 +- .../spark/sql/jdbc/TeradataDialect.scala | 5 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 116 ++++++++++++++---- 16 files changed, 276 insertions(+), 68 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java index 7e50bf14d7817..fa6447bc068d5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java @@ -20,7 +20,7 @@ import org.apache.spark.annotation.Evolving; /** - * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to * push down LIMIT. Please note that the combination of LIMIT with other operations * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. * diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java new file mode 100644 index 0000000000000..0212895fde079 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.SortOrder; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down top N(query with ORDER BY ... LIMIT n). Please note that the combination of top N + * with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. + * is NOT pushed down. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownTopN extends ScanBuilder { + + /** + * Pushes down top N to the data source. + */ + boolean pushTopN(SortOrder[] orders, int limit); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 18ad5b81560e1..8bc18ef253f5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -142,13 +142,23 @@ case class RowDataSourceScanExec( handledFilters } + val topNOrLimitInfo = + if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) { + val pushedTopN = + s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" + + s" LIMIT ${pushedDownOperators.limit.get}" + Some("pushedTopN" -> pushedTopN) + } else { + pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") + } + Map( "ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> seqToString(markedFilters.toSeq)) ++ pushedDownOperators.aggregation.fold(Map[String, String]()) { v => Map("PushedAggregates" -> seqToString(v.aggregateExpressions), "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ - pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++ + topNOrLimitInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index e8fb9ca3e46c3..84df3f8dd5b65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} @@ -336,7 +336,7 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, - PushedDownOperators(None, None, None), + PushedDownOperators(None, None, None, Seq.empty), toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -410,7 +410,7 @@ object DataSourceStrategy requestedColumns.toStructType, 
pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None), + PushedDownOperators(None, None, None, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -433,7 +433,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None), + PushedDownOperators(None, None, None, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -723,6 +723,24 @@ object DataSourceStrategy } } + protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = { + def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match { + case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => + val directionV2 = directionV1 match { + case Ascending => SortDirection.ASCENDING + case Descending => SortDirection.DESCENDING + } + val nullOrderingV2 = nullOrderingV1 match { + case NullsFirst => NullOrdering.NULLS_FIRST + case NullsLast => NullOrdering.NULLS_LAST + } + Some(SortValue(FieldReference(name), directionV2, nullOrderingV2)) + case _ => None + } + + sortOrders.flatMap(translateOortOrder) + } + /** * Convert RDD of Row into RDD of InternalRow with objects in catalyst types */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 394ba3f8bb8c2..baee53847a5a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ @@ -151,12 +152,14 @@ object JDBCRDD extends Logging { * @param options - JDBC options that contains url, table and other information. * @param outputSchema - The schema of the columns or aggregate columns to SELECT. * @param groupByColumns - The pushed down group by columns. + * @param sample - The pushed down tableSample. * @param limit - The pushed down limit. If the value is 0, it means no limit or limit * is not pushed down. - * @param sample - The pushed down tableSample. + * @param sortValues - The sort values cooperates with limit to realize top N. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". 
*/ + // scalastyle:off argcount def scanTable( sc: SparkContext, schema: StructType, @@ -167,7 +170,8 @@ object JDBCRDD extends Logging { outputSchema: Option[StructType] = None, groupByColumns: Option[Array[String]] = None, sample: Option[TableSampleInfo] = None, - limit: Int = 0): RDD[InternalRow] = { + limit: Int = 0, + sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = if (groupByColumns.isEmpty) { @@ -187,8 +191,10 @@ object JDBCRDD extends Logging { options, groupByColumns, sample, - limit) + limit, + sortOrders) } + // scalastyle:on argcount } /** @@ -207,7 +213,8 @@ private[jdbc] class JDBCRDD( options: JDBCOptions, groupByColumns: Option[Array[String]], sample: Option[TableSampleInfo], - limit: Int) + limit: Int, + sortOrders: Array[SortOrder]) extends RDD[InternalRow](sc, Nil) { /** @@ -255,6 +262,14 @@ private[jdbc] class JDBCRDD( } } + private def getOrderByClause: String = { + if (sortOrders.nonEmpty) { + s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}" + } else { + "" + } + } + /** * Runs the SQL query against the JDBC driver. * @@ -339,7 +354,7 @@ private[jdbc] class JDBCRDD( val myLimitClause: String = dialect.getLimitClause(limit) val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" + - s" $myWhereClause $getGroupByClause $myLimitClause" + s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index cd1eae89ee890..ecb207363cd59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf @@ -301,7 +302,8 @@ private[sql] case class JDBCRelation( filters: Array[Filter], groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], - limit: Int): RDD[Row] = { + limit: Int, + sortOrders: Array[SortOrder]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, @@ -313,7 +315,8 @@ private[sql] case class JDBCRelation( Some(finalSchema), groupByColumns, tableSample, - limit).asInstanceOf[RDD[Row]] + limit, + sortOrders).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index a98b8979d3e3f..2b26eee45221d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -22,10 +22,10 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources @@ -158,6 +158,17 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down top N to the data source Scan + */ + def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = { + scanBuilder match { + case s: SupportsPushDownTopN => + s.pushTopN(order, limit) + case _ => false + } + } + /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala index c21354d646164..20ced9c17f7e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation /** @@ -25,4 +26,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation case class PushedDownOperators( aggregation: Option[Aggregation], sample: Option[TableSampleInfo], - limit: Option[Int]) + limit: Option[Int], + sortValues: Seq[SortOrder]) { + assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index f73f831903364..148864e8e4b3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -246,17 +247,39 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match { + case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => + val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit) + if (limitPushed) { + sHolder.pushedLimit = Some(limit) + } + operation + case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder)) + if filter.isEmpty => + val orders = DataSourceStrategy.translateSortOrders(order) + if (orders.length == order.length) { + val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) + if (topNPushed) { + sHolder.pushedLimit = Some(limit) + sHolder.sortOrders = orders + operation + } else { + s + } + } else { + s + } + case p: Project => + val newChild = pushDownLimit(p.child, limit) + p.withNewChildren(Seq(newChild)) + case other => other + } + def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform { case globalLimit @ Limit(IntegerLiteral(limitValue), child) => - child match { - case ScanOperation(_, 
filter, sHolder: ScanBuilderHolder) if filter.length == 0 => - val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue) - if (limitPushed) { - sHolder.pushedLimit = Some(limitValue) - } - globalLimit - case _ => globalLimit - } + val newChild = pushDownLimit(child, limitValue) + val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild)) + globalLimit.withNewChildren(Seq(newLocalLimit)) } private def getWrappedScan( @@ -270,8 +293,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { f.pushedFilters() case _ => Array.empty[sources.Filter] } - val pushedDownOperators = - PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit) + val pushedDownOperators = PushedDownOperators(aggregation, + sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortOrders) V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan } @@ -284,6 +307,8 @@ case class ScanBuilderHolder( builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None + var sortOrders: Seq[SortOrder] = Seq.empty[SortOrder] + var pushedSample: Option[TableSampleInfo] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index ff79d1a5c4144..87ec9f43804e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -31,7 +32,8 @@ case class JDBCScan( pushedAggregateColumn: Array[String] = Array(), groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], - pushedLimit: Int) extends V1Scan { + pushedLimit: Int, + sortOrders: Array[SortOrder]) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -46,8 +48,8 @@ case class JDBCScan( } else { pushedAggregateColumn } - relation.buildScan( - columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, pushedLimit) + relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, + pushedLimit, sortOrders) } }.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index d3c141ed53c5c..1760122133d22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -20,8 +20,9 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, 
SupportsPushDownTableSample} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -39,6 +40,7 @@ case class JDBCScanBuilder( with SupportsPushDownAggregates with SupportsPushDownLimit with SupportsPushDownTableSample + with SupportsPushDownTopN with Logging { private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis @@ -51,6 +53,8 @@ case class JDBCScanBuilder( private var pushedLimit = 0 + private var sortOrders: Array[SortOrder] = Array.empty[SortOrder] + override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { val dialect = JdbcDialects.get(jdbcOptions.url) @@ -119,8 +123,17 @@ case class JDBCScanBuilder( } override def pushLimit(limit: Int): Boolean = { - if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) { + if (jdbcOptions.pushDownLimit) { + pushedLimit = limit + return true + } + false + } + + override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = { + if (jdbcOptions.pushDownLimit) { pushedLimit = limit + sortOrders = orders return true } false @@ -151,6 +164,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter, - pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit) + pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index ecb514abac01c..f19ef7ead5f8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -58,7 +58,7 @@ private object DerbyDialect extends JdbcDialect { throw QueryExecutionErrors.commentOnTableUnsupportedError() } - // ToDo: use fetch first n rows only for limit, e.g. 
- // select * from employee fetch first 10 rows only; - override def supportsLimit(): Boolean = false + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 6d90432859d71..5a445c5d56bdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -336,11 +336,6 @@ abstract class JdbcDialect extends Serializable with Logging{ if (limit > 0 ) s"LIMIT $limit" else "" } - /** - * returns whether the dialect supports limit or not - */ - def supportsLimit(): Boolean = true - def supportsTableSample: Boolean = false def getTableSample(sample: TableSampleInfo): String = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 8dad5ef8e1eae..8e5674a181e7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -119,6 +119,7 @@ private object MsSqlServerDialect extends JdbcDialect { throw QueryExecutionErrors.commentOnTableUnsupportedError() } - // ToDo: use top n to get limit, e.g. select top 100 * from employee; - override def supportsLimit(): Boolean = false + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 2a776bdb7ab04..13f4c5fe9c926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -56,6 +56,7 @@ private case object TeradataDialect extends JdbcDialect { s"RENAME TABLE $oldTable TO $newTable" } - // ToDo: use top n to get limit, e.g. 
select top 100 * from employee; - override def supportsLimit(): Boolean = false + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 39b3c19ac1db8..5f10f2ef105d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -23,7 +23,8 @@ import java.util.Properties import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException -import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sort} +import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{lit, sum, udf} @@ -102,14 +103,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case s: DataSourceV2ScanRelation => s }.get assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedLimit(df, true, 3) + checkPushedLimit(df, Some(3)) checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy"))) } test("simple scan with LIMIT") { val df1 = spark.read.table("h2.test.employee") .where($"dept" === 1).limit(1) - checkPushedLimit(df1, true, 1) + checkPushedLimit(df1, Some(1)) checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0))) val df2 = spark.read @@ -120,7 +121,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .filter($"dept" > 1) .limit(1) - checkPushedLimit(df2, true, 1) + checkPushedLimit(df2, Some(1)) checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0))) val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1") @@ -128,46 +129,117 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case s: DataSourceV2ScanRelation => s }.get assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedLimit(df3, true, 1) + checkPushedLimit(df3, Some(1)) checkAnswer(df3, Seq(Row("alex"))) val df4 = spark.read .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .limit(1) - checkPushedLimit(df4, false, 0) + checkPushedLimit(df4, None) checkAnswer(df4, Seq(Row(1, 19000.00))) - val df5 = spark.read - .table("h2.test.employee") - .sort("SALARY") - .limit(1) - checkPushedLimit(df5, false, 0) - checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0))) - val name = udf { (x: String) => x.matches("cat|dav|amy") } val sub = udf { (x: String) => x.substring(0, 3) } - val df6 = spark.read + val df5 = spark.read .table("h2.test.employee") .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) .filter(name($"shortName")) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkPushedLimit(df6, false, 0) - checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy"))) + checkPushedLimit(df5, None) + checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy"))) } - private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = { + private def checkPushedLimit(df: DataFrame, limit: Option[Int] = None, + sortValues: Seq[SortValue] = Nil): Unit = { df.queryExecution.optimizedPlan.collect { 
case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => - if (pushed) { - assert(v1.pushedDownOperators.limit === Some(limit)) - } else { - assert(v1.pushedDownOperators.limit.isEmpty) - } + assert(v1.pushedDownOperators.limit === limit) + assert(v1.pushedDownOperators.sortValues === sortValues) } } + if (sortValues.nonEmpty) { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } + assert(sorts.isEmpty) + } + } + + test("simple scan with top N") { + val df1 = spark.read + .table("h2.test.employee") + .sort("salary") + .limit(1) + checkPushedLimit(df1, Some(1), createSortValues()) + checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) + + val df2 = spark.read.table("h2.test.employee") + .where($"dept" === 1).orderBy($"salary").limit(1) + checkPushedLimit(df2, Some(1), createSortValues()) + checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0))) + + val df3 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .filter($"dept" > 1) + .orderBy($"salary".desc) + .limit(1) + checkPushedLimit( + df3, Some(1), createSortValues(SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) + checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0))) + + val df4 = + sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1") + val scan = df4.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq("NAME"))) + checkPushedLimit(df4, Some(1), createSortValues(nullOrdering = NullOrdering.NULLS_LAST)) + checkAnswer(df4, Seq(Row("david"))) + + val df5 = spark.read.table("h2.test.employee") + .where($"dept" === 1).orderBy($"salary") + checkPushedLimit(df5, None) + checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0))) + + val df6 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .orderBy("DEPT") + .limit(1) + checkPushedLimit(df6) + checkAnswer(df6, Seq(Row(1, 19000.00))) + + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val df7 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) + .sort($"SALARY".desc) + .limit(1) + // LIMIT is pushed down only if all the filters are pushed down + checkPushedLimit(df7) + checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy"))) + + val df8 = spark.read + .table("h2.test.employee") + .sort(sub($"NAME")) + .limit(1) + checkPushedLimit(df8) + checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0))) + } + + private def createSortValues( + sortDirection: SortDirection = SortDirection.ASCENDING, + nullOrdering: NullOrdering = NullOrdering.NULLS_FIRST): Seq[SortValue] = { + Seq(SortValue(FieldReference("salary"), sortDirection, nullOrdering)) } test("scan with filter push-down") { From abf7662cc950f7e55cc2a1f1d2540e85b309929d Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 23 Dec 2021 23:05:20 +0800 Subject: [PATCH 09/53] [SPARK-37644][SQL] Support datasource v2 complete aggregate pushdown ### What changes were proposed in this pull request? Currently , Spark supports push down aggregate with partial-agg and final-agg . For some data source (e.g. JDBC ) , we can avoid partial-agg and final-agg by running completely on database. ### Why are the changes needed? Improve performance for aggregate pushdown. 
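Below is a rough, self-contained sketch of the planning decision this enables, using toy plan nodes in place of Catalyst's `Aggregate`/`Project` and the scan relation. Only the `supportCompletePushDown` name mirrors the new API; the single-partition condition echoes the JDBC builder change later in this patch (`jdbcOptions.numPartitions.map(_ == 1).getOrElse(true)`), and everything else is illustrative:

```
// Toy model: when the source can run the aggregate completely (for JDBC, a
// single-partition read), Spark only needs a Project over the scan output; otherwise
// it keeps a final Aggregate that merges the partially aggregated values per partition.
object CompletePushDownSketch {
  sealed trait Plan
  case class Scan(output: Seq[String]) extends Plan
  case class Project(columns: Seq[String], child: Plan) extends Plan
  case class FinalAggregate(expressions: Seq[String], child: Plan) extends Plan

  // Stand-in for the new SupportsPushDownAggregates.supportCompletePushDown()
  def supportCompletePushDown(numPartitions: Int): Boolean = numPartitions == 1

  def planPushedAggregate(scanOutput: Seq[String], numPartitions: Int): Plan = {
    val scan = Scan(scanOutput)
    if (supportCompletePushDown(numPartitions)) {
      // The database already produced the final values, e.g. MAX(SALARY) per group.
      Project(scanOutput, scan)
    } else {
      // Each partition returns a partial MAX; Spark still re-aggregates to merge them.
      FinalAggregate(scanOutput.map(c => s"max($c)"), scan)
    }
  }

  def main(args: Array[String]): Unit = {
    println(planPushedAggregate(Seq("MAX(SALARY)"), numPartitions = 1))
    println(planPushedAggregate(Seq("MAX(SALARY)"), numPartitions = 4))
  }
}
```

For JDBC this means a single-partition read can skip both the partial and final aggregation stages entirely, while multi-partition reads keep the final merge step in Spark.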
### Does this PR introduce _any_ user-facing change? 'No'. Just change the inner implement. ### How was this patch tested? New tests. Closes #34904 from beliefer/SPARK-37644. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../read/SupportsPushDownAggregates.java | 8 ++ .../v2/V2ScanRelationPushDown.scala | 101 +++++++++++------- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 3 + .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 100 ++++++++++++++++- 4 files changed, 172 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 3e643b5493310..4e6c59e2881fb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -45,6 +45,14 @@ @Evolving public interface SupportsPushDownAggregates extends ScanBuilder { + /** + * Whether the datasource support complete aggregation push-down. Spark could avoid partial-agg + * and final-agg when the aggregation operation can be pushed down to the datasource completely. + * + * @return true if the aggregation can be pushed down to datasource completely, false otherwise. + */ + default boolean supportCompletePushDown() { return false; } + /** * Pushes down Aggregation to datasource. The order of the datasource scan output columns should * be: grouping columns, aggregate columns (in the same order as the aggregate functions in diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 148864e8e4b3b..ffb1187123844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation @@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StructType} import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { @@ -131,7 +131,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) case (_, b) => b } - val output = groupAttrs ++ 
newOutput.drop(groupAttrs.length) + val aggOutput = newOutput.drop(groupAttrs.length) + val output = groupAttrs ++ aggOutput logInfo( s""" @@ -147,40 +148,59 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - val plan = Aggregate( - output.take(groupingExpressions.length), resultExpressions, scanRelation) - - // scalastyle:off - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9, c2#10] ... - // - // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] - // we have the following - // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // scalastyle:on - val aggOutput = output.drop(groupAttrs.length) - plan.transformExpressions { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case max: aggregate.Max => max.copy(child = aggOutput(ordinal)) - case min: aggregate.Min => min.copy(child = aggOutput(ordinal)) - case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal)) - case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal)) - case other => other - } - agg.copy(aggregateFunction = aggFunction) + if (r.supportCompletePushDown()) { + val projectExpressions = resultExpressions.map { expr => + // TODO At present, only push down group by attribute is supported. + // In future, more attribute conversion is extended here. e.g. GetStructField + expr.transform { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val child = + addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) + Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) + } + }.asInstanceOf[Seq[NamedExpression]] + Project(projectExpressions, scanRelation) + } else { + val plan = Aggregate( + output.take(groupingExpressions.length), resultExpressions, scanRelation) + + // scalastyle:off + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... + // + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] + // we have the following + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... 
+ // scalastyle:on + plan.transformExpressions { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val aggAttribute = aggOutput(ordinal) + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case max: aggregate.Max => + max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType)) + case min: aggregate.Min => + min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType)) + case sum: aggregate.Sum => + sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType)) + case _: aggregate.Count => + aggregate.Sum(addCastIfNeeded(aggAttribute, LongType)) + case other => other + } + agg.copy(aggregateFunction = aggFunction) + } } } case _ => aggNode @@ -189,6 +209,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = + if (aggAttribute.dataType == aggDataType) { + aggAttribute + } else { + Cast(aggAttribute, aggDataType) + } + def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => // column pruning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 1760122133d22..01722e883831f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -72,6 +72,9 @@ case class JDBCScanBuilder( private var pushedGroupByCols: Option[Array[String]] = None + override def supportCompletePushDown: Boolean = + jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) + override def pushAggregation(aggregation: Aggregation): Boolean = { if (!jdbcOptions.pushDownAggregate) return false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 5f10f2ef105d0..c809551775e0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -23,7 +23,7 @@ import java.util.Properties import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException -import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -388,6 +388,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } + private def checkAggregateRemoved(df: DataFrame, removed: Boolean = true): Unit = { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + if (removed) { + assert(aggregates.isEmpty) + } else { + assert(aggregates.nonEmpty) + } + } + test("scan with aggregate push-down: MAX MIN with filter and group by") { val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DePt") 
@@ -395,6 +406,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case f: Filter => f } assert(filters.isEmpty) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -412,6 +424,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case f: Filter => f } assert(filters.isEmpty) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -425,6 +438,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: aggregate + number") { val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -436,6 +450,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: COUNT(*)") { val df = sql("select COUNT(*) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -447,6 +462,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: COUNT(col)") { val df = sql("select COUNT(DEPT) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -458,6 +474,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: COUNT(DISTINCT col)") { val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -469,6 +486,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: SUM without filer and group by") { val df = sql("SELECT SUM(SALARY) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -480,6 +498,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: DISTINCT SUM without filer and group by") { val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -491,6 +510,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: SUM with group by") { val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -504,6 +524,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: DISTINCT SUM with group by") { val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -518,10 +539,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with 
ExplainSuiteHel test("scan with aggregate push-down: with multiple group by columns") { val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DEPT, NAME") - val filters11 = df.queryExecution.optimizedPlan.collect { + val filters = df.queryExecution.optimizedPlan.collect { case f: Filter => f } - assert(filters11.isEmpty) + assert(filters.isEmpty) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -534,6 +556,60 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Row(10000, 1000), Row(12000, 1200))) } + test("scan with aggregate push-down: with concat multiple group key in project") { + val df1 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) FROM h2.test.employee" + + " where dept > 0 group by DEPT, NAME") + val filters1 = df1.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters1.isEmpty) + checkAggregateRemoved(df1) + df1.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [MAX(SALARY)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT, NAME]" + checkKeywordsExistsInExplain(df1, expected_plan_fragment) + } + checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000), + Row("2#david", 10000), Row("6#jen", 12000))) + + val df2 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" + + " FROM h2.test.employee where dept > 0 group by DEPT, NAME") + val filters2 = df2.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters2.isEmpty) + checkAggregateRemoved(df2) + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT, NAME]" + checkKeywordsExistsInExplain(df2, expected_plan_fragment) + } + checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), + Row("2#david", 11300), Row("6#jen", 13200))) + + val df3 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" + + " FROM h2.test.employee where dept > 0 group by concat_ws('#', DEPT, NAME)") + val filters3 = df3.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters3.isEmpty) + checkAggregateRemoved(df3, false) + df3.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + checkKeywordsExistsInExplain(df3, expected_plan_fragment) + } + checkAnswer(df3, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), + Row("2#david", 11300), Row("6#jen", 13200))) + } + test("scan with aggregate push-down: with having clause") { val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DEPT having MIN(BONUS) > 1000") @@ -541,6 +617,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case f: Filter => f // filter over aggregate not push down } assert(filters.nonEmpty) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -556,6 +633,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = sql("select 
* from h2.test.employee") .groupBy($"DEPT") .min("SALARY").as("total") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -579,6 +657,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case f: Filter => f } assert(filters.nonEmpty) // filter over aggregate not pushed down + checkAggregateRemoved(df) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -594,6 +673,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = spark.table("h2.test.employee") val decrease = udf { (x: Double, y: Double) => x - y } val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value")) + checkAggregateRemoved(query) query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = @@ -607,6 +687,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val cols = Seq("a", "b", "c", "d") val df1 = sql("select * from h2.test.employee").toDF(cols: _*) val df2 = df1.groupBy().sum("c") + checkAggregateRemoved(df2, false) df2.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => @@ -615,4 +696,17 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } checkAnswer(df2, Seq(Row(53000.00))) } + + test("scan with aggregate push-down: SUM(CASE WHEN) with group by") { + val df = + sql("SELECT SUM(CASE WHEN SALARY > 0 THEN 1 ELSE 0 END) FROM h2.test.employee GROUP BY DEPT") + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [], " + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(1), Row(2), Row(2))) + } } From aad72addd8879c44acad651c80eac63d2206c1a1 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 13 Dec 2021 22:51:40 +0800 Subject: [PATCH 10/53] [SPARK-37627][SQL] Add sorted column in BucketTransform ### What changes were proposed in this pull request? In V1, we can create table with sorted bucket like the following: ``` sql("CREATE TABLE tbl(a INT, b INT) USING parquet " + "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS") ``` However, creating table with sorted bucket in V2 failed with Exception `org.apache.spark.sql.AnalysisException: Cannot convert bucketing with sort columns to a transform.` ### Why are the changes needed? This PR adds sorted column in BucketTransform so we can create table in V2 with sorted bucket ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? new UT Closes #34879 from huaxingao/sortedBucket. 
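To make the new transform shape concrete: with this change `CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS` converts to a bucket transform that also carries the sort column, rendered by `describe` as `bucket(5, a, b)`. Below is a standalone sketch of that rendering, with plain strings standing in for Spark's `NamedReference` and literal types; it is illustrative only, not the actual classes:

```
// Sketch of how the extended BucketTransform renders its optional sort columns.
object BucketTransformSketch {
  case class BucketTransform(
      numBuckets: Int,
      columns: Seq[String],
      sortedColumns: Seq[String] = Seq.empty) {
    // Mirrors the describe() added in this patch: sort columns are appended after the
    // bucket count and bucket columns when present.
    def describe: String = {
      val bucketArgs = (numBuckets +: columns).mkString(", ")
      if (sortedColumns.nonEmpty) s"bucket($bucketArgs, ${sortedColumns.mkString(", ")})"
      else s"bucket($bucketArgs)"
    }
  }

  def main(args: Array[String]): Unit = {
    // CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS  ==>  bucket(5, a, b)
    println(BucketTransform(5, Seq("a"), Seq("b")).describe)
    // No SORTED BY clause  ==>  bucket(5, a)
    println(BucketTransform(5, Seq("a")).describe)
  }
}
```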
Authored-by: Huaxin Gao Signed-off-by: Wenchen Fan --- .../catalog/CatalogV2Implicits.scala | 9 ++--- .../connector/expressions/expressions.scala | 36 ++++++++++++++----- .../sql/errors/QueryCompilationErrors.scala | 7 +--- .../sql/connector/catalog/InMemoryTable.scala | 2 +- .../expressions/TransformExtractorSuite.scala | 4 +-- .../datasources/v2/V2SessionCatalog.scala | 4 +-- .../sql/connector/DataSourceV2SQLSuite.scala | 18 ++++++++++ 7 files changed, 57 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 39642fd541706..185a1a2644e2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -38,12 +38,13 @@ private[sql] object CatalogV2Implicits { implicit class BucketSpecHelper(spec: BucketSpec) { def asTransform: BucketTransform = { + val references = spec.bucketColumnNames.map(col => reference(Seq(col))) if (spec.sortColumnNames.nonEmpty) { - throw QueryCompilationErrors.cannotConvertBucketWithSortColumnsToTransformError(spec) + val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col))) + bucket(spec.numBuckets, references.toArray, sortedCol.toArray) + } else { + bucket(spec.numBuckets, references.toArray) } - - val references = spec.bucketColumnNames.map(col => reference(Seq(col))) - bucket(spec.numBuckets, references.toArray) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 2863d94d198b2..e52654ac69c96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -45,6 +45,12 @@ private[sql] object LogicalExpressions { def bucket(numBuckets: Int, references: Array[NamedReference]): BucketTransform = BucketTransform(literal(numBuckets, IntegerType), references) + def bucket( + numBuckets: Int, + references: Array[NamedReference], + sortedCols: Array[NamedReference]): BucketTransform = + BucketTransform(literal(numBuckets, IntegerType), references, sortedCols) + def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference) def years(reference: NamedReference): YearsTransform = YearsTransform(reference) @@ -97,7 +103,8 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R private[sql] final case class BucketTransform( numBuckets: Literal[Int], - columns: Seq[NamedReference]) extends RewritableTransform { + columns: Seq[NamedReference], + sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform { override val name: String = "bucket" @@ -107,7 +114,13 @@ private[sql] final case class BucketTransform( override def arguments: Array[Expression] = numBuckets +: columns.toArray - override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})" + override def describe: String = + if (sortedColumns.nonEmpty) { + s"bucket(${arguments.map(_.describe).mkString(", ")}," + + s" ${sortedColumns.map(_.describe).mkString(", ")})" + } else { + s"bucket(${arguments.map(_.describe).mkString(", ")})" + } override def toString: String = describe @@ -117,11 +130,12 @@ 
private[sql] final case class BucketTransform( } private[sql] object BucketTransform { - def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match { + def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] = + expr match { case transform: Transform => transform match { - case BucketTransform(n, FieldReference(parts)) => - Some((n, FieldReference(parts))) + case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) => + Some((n, FieldReference(parts), FieldReference(sortCols))) case _ => None } @@ -129,11 +143,17 @@ private[sql] object BucketTransform { None } - def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match { + def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] = + transform match { + case NamedTransform("bucket", Seq( + Lit(value: Int, IntegerType), + Ref(partCols: Seq[String]), + Ref(sortCols: Seq[String]))) => + Some((value, FieldReference(partCols), FieldReference(sortCols))) case NamedTransform("bucket", Seq( Lit(value: Int, IntegerType), - Ref(seq: Seq[String]))) => - Some((value, FieldReference(seq))) + Ref(partCols: Seq[String]))) => + Some((value, FieldReference(partCols), FieldReference(Seq.empty[String]))) case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index e7af006ad7023..ef262a88b7ecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, NoSuchTableException, ResolvedNamespace, ResolvedTable, ResolvedView, TableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, InvalidUDFClassException} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, InvalidUDFClassException} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.JoinType @@ -1371,11 +1371,6 @@ object QueryCompilationErrors { new AnalysisException("Cannot use interval type in the table schema.") } - def cannotConvertBucketWithSortColumnsToTransformError(spec: BucketSpec): Throwable = { - new AnalysisException( - s"Cannot convert bucketing with sort columns to a transform: $spec") - } - def cannotConvertTransformsToPartitionColumnsError(nonIdTransforms: Seq[Transform]): Throwable = { new AnalysisException("Transforms cannot be converted to partition columns: " + nonIdTransforms.map(_.describe).mkString(", ")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 2f3c5a38538c8..e0604576a94bc 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -161,7 +161,7 @@ class InMemoryTable( case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } - case BucketTransform(numBuckets, ref) => + case BucketTransform(numBuckets, ref, _) => val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) val valueHashCode = if (value == null) 0 else value.hashCode ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index fbd6a886d011b..340d225f80fdb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -139,7 +139,7 @@ class TransformExtractorSuite extends SparkFunSuite { } bucketTransform match { - case BucketTransform(numBuckets, FieldReference(seq)) => + case BucketTransform(numBuckets, FieldReference(seq), _) => assert(numBuckets === 16) assert(seq === Seq("a", "b")) case _ => @@ -147,7 +147,7 @@ class TransformExtractorSuite extends SparkFunSuite { } transform("unknown", ref("a", "b")) match { - case BucketTransform(_, _) => + case BucketTransform(_, _, _) => fail("Matched unknown transform") case _ => // expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 33b8f22e3f88a..d4a981d2205da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -293,8 +293,8 @@ private[sql] object V2SessionCatalog { case IdentityTransform(FieldReference(Seq(col))) => identityCols += col - case BucketTransform(numBuckets, FieldReference(Seq(col))) => - bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil)) + case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) => + bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil)) case transform => throw QueryExecutionErrors.unsupportedPartitionTransformError(transform) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index a326b82dbaf1e..7b941ab0d8f7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1609,6 +1609,24 @@ class DataSourceV2SQLSuite } } + test("create table using - with sorted bucket") { + val identifier = "testcat.table_name" + withTable(identifier) { + sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" + + s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS") + val table = getTableMetadata(identifier) + val describe = spark.sql(s"DESCRIBE $identifier") + val part1 = describe + .filter("col_name = 'Part 0'") + .select("data_type").head.getString(0) + assert(part1 === "c") + val part2 = describe + 
.filter("col_name = 'Part 1'") + .select("data_type").head.getString(0) + assert(part2 === "bucket(4, b, a)") + } + } + test("REFRESH TABLE: v2 table") { val t = "testcat.ns1.ns2.tbl" withTable(t) { From c6d90e87732d1bf80fc0c267f37fefc81316e22b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 4 Jan 2022 17:23:38 +0800 Subject: [PATCH 11/53] [SPARK-37789][SQL] Add a class to represent general aggregate functions in DS V2 ### What changes were proposed in this pull request? There are a lot of aggregate functions in SQL and it's a lot of work to add them one by one in the DS v2 API. This PR proposes to add a new `GeneralAggregateFunc` class to represent all the general SQL aggregate functions. Since it's general, Spark doesn't know its aggregation buffer and can only push down the aggregation to the source completely. As an example, this PR also translates `AVG` to `GeneralAggregateFunc` and pushes it to JDBC V2. ### Why are the changes needed? To add aggregate functions in DS v2 easier. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? JDBC v2 test Closes #35070 from cloud-fan/agg. Lead-authored-by: Wenchen Fan Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/connector/expressions/Expression.java | 2 +- .../expressions/aggregate/Count.java | 3 - .../expressions/aggregate/CountStar.java | 3 - .../aggregate/GeneralAggregateFunc.java | 66 +++++++++++++++++++ .../connector/expressions/aggregate/Max.java | 3 - .../connector/expressions/aggregate/Min.java | 3 - .../connector/expressions/aggregate/Sum.java | 3 - .../connector/expressions/filter/Filter.java | 3 - .../read/SupportsPushDownAggregates.java | 21 +++--- .../connector/expressions/expressions.scala | 20 ++---- .../expressions/TransformExtractorSuite.scala | 8 +-- .../sql/execution/DataSourceScanExec.scala | 4 +- .../datasources/DataSourceStrategy.scala | 6 +- .../v2/V2ScanRelationPushDown.scala | 16 +++-- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 2 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 8 ++- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 33 +++++++--- 17 files changed, 137 insertions(+), 67 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java index 6540c91597582..9f6c0975ae0e1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java @@ -29,5 +29,5 @@ public interface Expression { /** * Format the expression as a human readable SQL-like string. 
*/ - String describe(); + default String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java index 1273886e297bf..1685770604a46 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java @@ -46,7 +46,4 @@ public String toString() { return "COUNT(" + column.describe() + ")"; } } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java index f566ad164b8ef..13801194b63cb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java @@ -32,7 +32,4 @@ public CountStar() { @Override public String toString() { return "COUNT(*)"; } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java new file mode 100644 index 0000000000000..e0d95cfaafbb0 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import java.util.Arrays; +import java.util.stream.Collectors; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * The general implementation of {@link AggregateFunc}, which contains the upper-cased function + * name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down partial + * aggregate with this function to the source, but can only push down the entire aggregate. + *

+ * The currently supported SQL aggregate functions: + *

    + *
  1. AVG(input1)
    Since 3.3.0
  2. + *
+ * + * @since 3.3.0 + */ +@Evolving +public final class GeneralAggregateFunc implements AggregateFunc { + private final String name; + private final boolean isDistinct; + private final NamedReference[] inputs; + + public String name() { return name; } + public boolean isDistinct() { return isDistinct; } + public NamedReference[] inputs() { return inputs; } + + public GeneralAggregateFunc(String name, boolean isDistinct, NamedReference[] inputs) { + this.name = name; + this.isDistinct = isDistinct; + this.inputs = inputs; + } + + @Override + public String toString() { + String inputsString = Arrays.stream(inputs) + .map(Expression::describe) + .collect(Collectors.joining(", ")); + if (isDistinct) { + return name + "(DISTINCT " + inputsString + ")"; + } else { + return name + "(" + inputsString + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java index ed07cc9e32187..5acdf14bf7e2f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java @@ -35,7 +35,4 @@ public final class Max implements AggregateFunc { @Override public String toString() { return "MAX(" + column.describe() + ")"; } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java index 2e761037746fb..824c607ea7df0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java @@ -35,7 +35,4 @@ public final class Min implements AggregateFunc { @Override public String toString() { return "MIN(" + column.describe() + ")"; } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java index 057ebd89f7a19..6b04dc38c2846 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java @@ -46,7 +46,4 @@ public String toString() { return "SUM(" + column.describe() + ")"; } } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java index aa1fa082dc92c..af87e76d2ff7d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java @@ -37,7 +37,4 @@ public abstract class Filter implements Expression, Serializable { * Returns list of columns that are referenced by this filter. 
*/ public abstract NamedReference[] references(); - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 4e6c59e2881fb..1b178d7f2be74 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -22,18 +22,19 @@ /** * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to - * push down aggregates. Spark assumes that the data source can't fully complete the - * grouping work, and will group the data source output again. For queries like - * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate - * to the data source, the data source can still output data with duplicated keys, which is OK - * as Spark will do GROUP BY key again. The final query plan can be something like this: + * push down aggregates. + *

+ * If the data source can't fully complete the grouping work, then + * {@link #supportCompletePushDown()} should return false, and Spark will group the data source + * output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after pushing down + * the aggregate to the data source, the data source can still output data with duplicated keys, + * which is OK as Spark will do GROUP BY key again. The final query plan can be something like this: *

- *   Aggregate [key#1], [min(min(value)#2) AS m#3]
- *     +- RelationV2[key#1, min(value)#2]
+ *   Aggregate [key#1], [min(min_value#2) AS m#3]
+ *     +- RelationV2[key#1, min_value#2]
  * 
* Similarly, if there is no grouping expression, the data source can still output more than one * rows. - * *

* When pushing down operators, Spark pushes down filters to the data source first, then push down * aggregates or apply column pruning. Depends on data source implementation, aggregates may or @@ -46,8 +47,8 @@ public interface SupportsPushDownAggregates extends ScanBuilder { /** - * Whether the datasource support complete aggregation push-down. Spark could avoid partial-agg - * and final-agg when the aggregation operation can be pushed down to the datasource completely. + * Whether the datasource support complete aggregation push-down. Spark will do grouping again + * if this method returns false. * * @return true if the aggregation can be pushed down to datasource completely, false otherwise. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index e52654ac69c96..e3eab6f6730f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -88,9 +88,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R override def arguments: Array[Expression] = Array(ref) - override def describe: String = name + "(" + reference.describe + ")" - - override def toString: String = describe + override def toString: String = name + "(" + reference.describe + ")" protected def withNewRef(ref: NamedReference): Transform @@ -114,7 +112,7 @@ private[sql] final case class BucketTransform( override def arguments: Array[Expression] = numBuckets +: columns.toArray - override def describe: String = + override def toString: String = if (sortedColumns.nonEmpty) { s"bucket(${arguments.map(_.describe).mkString(", ")}," + s" ${sortedColumns.map(_.describe).mkString(", ")})" @@ -122,8 +120,6 @@ private[sql] final case class BucketTransform( s"bucket(${arguments.map(_.describe).mkString(", ")})" } - override def toString: String = describe - override def withReferences(newReferences: Seq[NamedReference]): Transform = { this.copy(columns = newReferences) } @@ -169,9 +165,7 @@ private[sql] final case class ApplyTransform( arguments.collect { case named: NamedReference => named } } - override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})" - - override def toString: String = describe + override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})" } /** @@ -338,21 +332,19 @@ private[sql] object HoursTransform { } private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { - override def describe: String = { + override def toString: String = { if (dataType.isInstanceOf[StringType]) { s"'$value'" } else { s"$value" } } - override def toString: String = describe } private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper override def fieldNames: Array[String] = parts.toArray - override def describe: String = parts.quoted - override def toString: String = describe + override def toString: String = parts.quoted } private[sql] object FieldReference { @@ -366,7 +358,7 @@ private[sql] final case class SortValue( direction: SortDirection, nullOrdering: NullOrdering) extends SortOrder { - override def describe(): String = s"$expression $direction $nullOrdering" + override def toString(): String = s"$expression $direction 
$nullOrdering" } private[sql] object SortValue { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index 340d225f80fdb..b2371ce667ffc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -28,7 +28,7 @@ class TransformExtractorSuite extends SparkFunSuite { private def lit[T](literal: T): Literal[T] = new Literal[T] { override def value: T = literal override def dataType: DataType = catalyst.expressions.Literal(literal).dataType - override def describe: String = literal.toString + override def toString: String = literal.toString } /** @@ -36,7 +36,7 @@ class TransformExtractorSuite extends SparkFunSuite { */ private def ref(names: String*): NamedReference = new NamedReference { override def fieldNames: Array[String] = names.toArray - override def describe: String = names.mkString(".") + override def toString: String = names.mkString(".") } /** @@ -46,7 +46,7 @@ class TransformExtractorSuite extends SparkFunSuite { override def name: String = func override def references: Array[NamedReference] = Array(ref) override def arguments: Array[Expression] = Array(ref) - override def describe: String = ref.describe + override def toString: String = ref.describe } test("Identity extractor") { @@ -135,7 +135,7 @@ class TransformExtractorSuite extends SparkFunSuite { override def name: String = "bucket" override def references: Array[NamedReference] = Array(col) override def arguments: Array[Expression] = Array(lit(16), col) - override def describe: String = s"bucket(16, ${col.describe})" + override def toString: String = s"bucket(16, ${col.describe})" } bucketTransform match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 8bc18ef253f5c..ae444bf3aabf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -156,8 +156,8 @@ case class RowDataSourceScanExec( "ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> seqToString(markedFilters.toSeq)) ++ pushedDownOperators.aggregation.fold(Map[String, String]()) { v => - Map("PushedAggregates" -> seqToString(v.aggregateExpressions), - "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ + Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())), + "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++ topNOrLimitInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 84df3f8dd5b65..9b1d268f49c01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import 
org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -714,8 +714,10 @@ object DataSourceStrategy Some(new Count(FieldReference(name), aggregates.isDistinct)) case _ => None } - case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => + case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => Some(new Sum(FieldReference(name), aggregates.isDistinct)) + case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc("AVG", aggregates.isDistinct, Array(FieldReference(name)))) case _ => None } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index ffb1187123844..8ae0a4d8af025 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.SortOrder -import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources @@ -109,6 +109,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { sHolder.builder, normalizedAggregates, normalizedGroupingExpressions) if (pushedAggregates.isEmpty) { aggNode // return original plan node + } else if (!supportPartialAggPushDown(pushedAggregates.get) && + !r.supportCompletePushDown()) { + aggNode // return original plan node } else { // No need to do column pruning because only the aggregate columns are used as // DataSourceV2ScanRelation output columns. All the other columns are not @@ -145,9 +148,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { """.stripMargin) val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) - val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - if (r.supportCompletePushDown()) { val projectExpressions = resultExpressions.map { expr => // TODO At present, only push down group by attribute is supported. 
@@ -209,6 +210,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def supportPartialAggPushDown(agg: Aggregation): Boolean = { + // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. + agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc]) + } + private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = if (aggAttribute.dataType == aggDataType) { aggAttribute @@ -256,7 +262,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform { case sample: Sample => sample.child match { - case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 => + case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => val tableSample = TableSampleInfo( sample.lowerBound, sample.upperBound, @@ -282,7 +288,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } operation case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder)) - if filter.isEmpty => + if filter.isEmpty => val orders = DataSourceStrategy.translateSortOrders(order) if (orders.length == order.length) { val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 01722e883831f..2d01a3e6842b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -79,7 +79,7 @@ case class JDBCScanBuilder( if (!jdbcOptions.pushDownAggregate) return false val dialect = JdbcDialects.get(jdbcOptions.url) - val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate(_)) + val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate) if (compiledAggs.length != aggregation.aggregateExpressions.length) return false val groupByCols = aggregation.groupByColumns.map { col => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 5a445c5d56bdf..e516960bb6746 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -216,7 +216,11 @@ abstract class JdbcDialect extends Serializable with Logging{ val column = quoteIdentifier(sum.column.fieldNames.head) Some(s"SUM($distinct$column)") case _: CountStar => - Some(s"COUNT(*)") 
+ Some("COUNT(*)") + case f: GeneralAggregateFunc if f.name() == "AVG" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"AVG($distinct${f.inputs().head})") case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index c809551775e0b..94709c27f3784 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -399,8 +399,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } - test("scan with aggregate push-down: MAX MIN with filter and group by") { - val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" + + test("scan with aggregate push-down: MAX AVG with filter and group by") { + val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" + " group by DePt") val filters = df.queryExecution.optimizedPlan.collect { case f: Filter => f @@ -410,16 +410,16 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } - checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200))) + checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) } - test("scan with aggregate push-down: MAX MIN with filter without group by") { - val df = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0") + test("scan with aggregate push-down: MAX AVG with filter without group by") { + val df = sql("select MAX(ID), AVG(ID) FROM h2.test.people where id > 0") val filters = df.queryExecution.optimizedPlan.collect { case f: Filter => f } @@ -428,12 +428,29 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregates: [MAX(ID), MIN(ID)], " + + "PushedAggregates: [MAX(ID), AVG(ID)], " + "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " + "PushedGroupByColumns: []" checkKeywordsExistsInExplain(df, expected_plan_fragment) } - checkAnswer(df, Seq(Row(2, 1))) + checkAnswer(df, Seq(Row(2, 1.0))) + } + + test("partitioned scan with aggregate push-down: complete push-down only") { + withTempView("v") { + spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .createTempView("v") + val df = sql("select AVG(SALARY) FROM v GROUP BY name") + // Partitioned JDBC Scan doesn't support complete aggregate push-down, and AVG requires + // complete push-down so aggregate is not pushed at the end. 
+ checkAggregateRemoved(df, removed = false) + checkAnswer(df, Seq(Row(9000.0), Row(10000.0), Row(10000.0), Row(12000.0), Row(12000.0))) + } } test("scan with aggregate push-down: aggregate + number") { From 576b1fbeb7ebde31107bb94d4c08354545a0f630 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 5 Jan 2022 22:08:11 +0800 Subject: [PATCH 12/53] [SPARK-37644][SQL][FOLLOWUP] When partition column is same as group by key, pushing down aggregate completely ### What changes were proposed in this pull request? When JDBC option specifying the "partitionColumn" and it's the same as group by key, the aggregate push-down should be completely. ### Why are the changes needed? Improve the datasource v2 complete aggregate pushdown. ### Does this PR introduce _any_ user-facing change? 'No'. Just change the inner implement. ### How was this patch tested? New tests. Closes #35052 from beliefer/SPARK-37644-followup. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../read/SupportsPushDownAggregates.java | 13 +++--- .../v2/V2ScanRelationPushDown.scala | 4 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 8 +++- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 41 +++++++++++++++++++ 4 files changed, 57 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 1b178d7f2be74..4d88ec19c897b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -25,10 +25,11 @@ * push down aggregates. *

* If the data source can't fully complete the grouping work, then - * {@link #supportCompletePushDown()} should return false, and Spark will group the data source - * output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after pushing down - * the aggregate to the data source, the data source can still output data with duplicated keys, - * which is OK as Spark will do GROUP BY key again. The final query plan can be something like this: + * {@link #supportCompletePushDown(Aggregation)} should return false, and Spark will group the data + * source output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after + * pushing down the aggregate to the data source, the data source can still output data with + * duplicated keys, which is OK as Spark will do GROUP BY key again. The final query plan can be + * something like this: *

  *   Aggregate [key#1], [min(min_value#2) AS m#3]
  *     +- RelationV2[key#1, min_value#2]
@@ -50,15 +51,17 @@ public interface SupportsPushDownAggregates extends ScanBuilder {
    * Whether the datasource support complete aggregation push-down. Spark will do grouping again
    * if this method returns false.
    *
+   * @param aggregation Aggregation in SQL statement.
    * @return true if the aggregation can be pushed down to datasource completely, false otherwise.
    */
-  default boolean supportCompletePushDown() { return false; }
+  default boolean supportCompletePushDown(Aggregation aggregation) { return false; }
 
   /**
    * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
    * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
    * the given Aggregation).
    *
+   * @param aggregation Aggregation in SQL statement.
    * @return true if the aggregation can be pushed down to datasource, false otherwise.
    */
   boolean pushAggregation(Aggregation aggregation);
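
As a point of reference for the updated contract, here is a minimal sketch of a connector-side `ScanBuilder` using the new `supportCompletePushDown(Aggregation)` signature. `MyScanBuilder` is a hypothetical name and the single-group-by-column rule is only an assumed policy for illustration, not code from this patch:

```scala
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

class MyScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
  private var pushedAggregation: Option[Aggregation] = None

  // Claim complete push-down only when this hypothetical source can emit exactly one
  // row per group; here we assume that is the single-group-by-column case.
  override def supportCompletePushDown(aggregation: Aggregation): Boolean =
    aggregation.groupByColumns().length == 1

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    // General aggregate functions have no known buffer, so this sketch rejects them.
    if (aggregation.aggregateExpressions().exists(_.isInstanceOf[GeneralAggregateFunc])) {
      false
    } else {
      pushedAggregation = Some(aggregation)
      true
    }
  }

  // Construct the connector's Scan from pushedAggregation (omitted in this sketch).
  override def build(): Scan = ???
}
```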
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 8ae0a4d8af025..67002e50e4680 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -110,7 +110,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
               if (pushedAggregates.isEmpty) {
                 aggNode // return original plan node
               } else if (!supportPartialAggPushDown(pushedAggregates.get) &&
-                !r.supportCompletePushDown()) {
+                !r.supportCompletePushDown(pushedAggregates.get)) {
                 aggNode // return original plan node
               } else {
                 // No need to do column pruning because only the aggregate columns are used as
@@ -149,7 +149,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
                 val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
                 val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
-                if (r.supportCompletePushDown()) {
+                if (r.supportCompletePushDown(pushedAggregates.get)) {
                   val projectExpressions = resultExpressions.map { expr =>
                     // TODO At present, only push down group by attribute is supported.
                     // In future, more attribute conversion is extended here. e.g. GetStructField
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 2d01a3e6842b3..61bf729bc8fbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -72,8 +72,12 @@ case class JDBCScanBuilder(
 
   private var pushedGroupByCols: Option[Array[String]] = None
 
-  override def supportCompletePushDown: Boolean =
-    jdbcOptions.numPartitions.map(_ == 1).getOrElse(true)
+  override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
+    lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames()
+    jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) ||
+      (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 &&
+        jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_)))
+  }
 
   override def pushAggregation(aggregation: Aggregation): Boolean = {
     if (!jdbcOptions.pushDownAggregate) return false
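
The new JDBC condition above can be restated as a rough standalone sketch (a hypothetical helper, not part of this patch): complete push-down is claimed when the scan runs as a single partition, or when the only group-by column is the JDBC partition column, so each group is fully contained in one partition.

```scala
// Hypothetical restatement of the condition in JDBCScanBuilder.supportCompletePushDown.
def canCompletelyPushDown(
    numPartitions: Option[Int],
    partitionColumn: Option[String],
    groupByFieldNames: Seq[Seq[String]]): Boolean = {
  val singlePartition = numPartitions.forall(_ == 1)
  val groupByIsPartitionColumn =
    groupByFieldNames.length == 1 && groupByFieldNames.head.length == 1 &&
      partitionColumn.exists(pc => groupByFieldNames.head.head.equalsIgnoreCase(pc))
  singlePartition || groupByIsPartitionColumn
}
```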
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 94709c27f3784..fab141e1a9f50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -700,6 +700,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(query, Seq(Row(47100.0)))
   }
 
+  test("scan with aggregate push-down: partition columns are same as group by columns") {
+    val df = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy($"dept")
+      .count()
+    checkAggregateRemoved(df)
+    checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1)))
+  }
+
   test("scan with aggregate push-down: aggregate over alias NOT push down") {
     val cols = Seq("a", "b", "c", "d")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
@@ -726,4 +739,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
     checkAnswer(df, Seq(Row(1), Row(2), Row(2)))
   }
+
+  test("scan with aggregate push-down: partition columns with multi group by columns") {
+    val df = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy($"dept", $"name")
+      .count()
+    checkAggregateRemoved(df, false)
+    checkAnswer(df, Seq(Row(1, "amy", 1), Row(1, "cathy", 1),
+      Row(2, "alex", 1), Row(2, "david", 1), Row(6, "jen", 1)))
+  }
+
+  test("scan with aggregate push-down: partition columns is different from group by columns") {
+    val df = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .groupBy($"name")
+      .count()
+    checkAggregateRemoved(df, false)
+    checkAnswer(df,
+      Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1)))
+  }
 }

From cc970f156b973303987a85406851c4e36116fb6f Mon Sep 17 00:00:00 2001
From: Jiaan Geng 
Date: Thu, 6 Jan 2022 14:17:15 -0800
Subject: [PATCH 13/53] [SPARK-37527][SQL] Translate more standard aggregate
 functions for pushdown

### What changes were proposed in this pull request?
With this change, Spark aggregate push-down translates more standard aggregate functions, so that JDBC dialects can compile them into database-specific SQL.
After this change, users can override `JdbcDialect.compileAggregate` to implement the standard aggregate functions supported by a given database.
This PR only translates the ANSI standard aggregate functions. Mainstream databases support these functions as shown below:
| Name | ClickHouse | Presto | Teradata | Snowflake | Oracle | Postgresql | Vertica | MySQL | RedShift | ElasticSearch | Impala | Druid | SyBase | DB2 | H2 | Exasol | Mariadb | Phoenix | Yellowbrick | Singlestore | Influxdata | Dolphindb | Intersystems |
|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|
| `VAR_POP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | No | Yes | Yes |
| `VAR_SAMP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No |  Yes | Yes | Yes | No | Yes | Yes | No | Yes | Yes |
| `STDDEV_POP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `STDDEV_SAMP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No |  Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes |
| `COVAR_POP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | Yes | No |  Yes | Yes | No | No | No | No | Yes | Yes | No |
| `COVAR_SAMP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | Yes | No |  Yes | Yes | No | No | No | No | No | No | No |
| `CORR` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | Yes | No |  Yes | Yes | No | No | No | No | No | Yes | No |

Because some aggregate functions are rewritten by the Optimizer as shown below, this PR does not need to match them.

|Input|Parsed|Optimized|
|------|--------------------|----------|
|`Every`| `aggregate.BoolAnd` |`Min`|
|`Any`| `aggregate.BoolOr` |`Max`|
|`Some`| `aggregate.BoolOr` |`Max`|

### Why are the changes needed?
Allow `*Dialect` implementations to extend the set of supported aggregate functions by overriding `JdbcDialect.compileAggregate`, as sketched below.
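
As a sketch of what such an override could look like, mirroring the pattern the H2 dialect adopts later in this patch (`MyDialect` and the `jdbc:mydb` URL prefix are made-up names):

```scala
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.jdbc.JdbcDialect

object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
    // Fall back to the built-in translations (MAX, MIN, SUM, COUNT, AVG) first,
    // then handle the general aggregate functions this database understands.
    super.compileAggregate(aggFunction).orElse(aggFunction match {
      case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
        assert(f.inputs().length == 1)
        val distinct = if (f.isDistinct) "DISTINCT " else ""
        Some(s"VAR_POP($distinct${f.inputs().head})")
      case _ => None
    })
  }
}
```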

### Does this PR introduce _any_ user-facing change?
Yes. Users can push down more aggregate functions.

### How was this patch tested?
Existing tests.

Closes #35101 from beliefer/SPARK-37527-new2.

Authored-by: Jiaan Geng 
Signed-off-by: Huaxin Gao 
---
 .../aggregate/GeneralAggregateFunc.java       |  7 +++
 .../datasources/DataSourceStrategy.scala      | 21 +++++++
 .../org/apache/spark/sql/jdbc/H2Dialect.scala | 25 ++++++++
 .../apache/spark/sql/jdbc/JDBCV2Suite.scala   | 60 +++++++++++++++++++
 4 files changed, 113 insertions(+)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index e0d95cfaafbb0..32615e201643b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -32,6 +32,13 @@
  * The currently supported SQL aggregate functions:
  * 
    *
  1. AVG(input1)
    Since 3.3.0
  2. + *
  3. VAR_POP(input1)
    Since 3.3.0
  4. + *
  5. VAR_SAMP(input1)
    Since 3.3.0
  6. + *
  7. STDDEV_POP(input1)
    Since 3.3.0
  8. + *
  9. STDDEV_SAMP(input1)
    Since 3.3.0
  10. + *
  11. COVAR_POP(input1, input2)
    Since 3.3.0
  12. + *
  13. COVAR_SAMP(input1, input2)
    Since 3.3.0
  14. + *
  15. CORR(input1, input2)
    Since 3.3.0
  16. *
* * @since 3.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 9b1d268f49c01..990a00ca918fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -718,6 +718,27 @@ object DataSourceStrategy Some(new Sum(FieldReference(name), aggregates.isDistinct)) case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) => Some(new GeneralAggregateFunc("AVG", aggregates.isDistinct, Array(FieldReference(name)))) + case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc("VAR_POP", aggregates.isDistinct, Array(FieldReference(name)))) + case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc("VAR_SAMP", aggregates.isDistinct, Array(FieldReference(name)))) + case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc("STDDEV_POP", aggregates.isDistinct, Array(FieldReference(name)))) + case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc("STDDEV_SAMP", aggregates.isDistinct, Array(FieldReference(name)))) + case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left), + PushableColumnWithoutNestedColumn(right), _) => + Some(new GeneralAggregateFunc("COVAR_POP", aggregates.isDistinct, + Array(FieldReference(left), FieldReference(right)))) + case aggregate.CovSample(PushableColumnWithoutNestedColumn(left), + PushableColumnWithoutNestedColumn(right), _) => + Some(new GeneralAggregateFunc("COVAR_SAMP", aggregates.isDistinct, + Array(FieldReference(left), FieldReference(right)))) + case aggregate.Corr(PushableColumnWithoutNestedColumn(left), + PushableColumnWithoutNestedColumn(right), _) => + Some(new GeneralAggregateFunc("CORR", aggregates.isDistinct, + Array(FieldReference(left), FieldReference(right)))) + case _ => None } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 9c727957ffab8..087c3573fbdbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -22,11 +22,36 @@ import java.util.Locale import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} private object H2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + 
assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def classifyException(message: String, e: Throwable): AnalysisException = { if (e.isInstanceOf[SQLException]) { // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index fab141e1a9f50..a84d06b33534c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -713,6 +713,66 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1))) } + test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") { + val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" + + " group by DePt") + checkFiltersRemoved(df) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + } + + test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") { + val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" + + " where dept > 0 group by DePt") + checkFiltersRemoved(df) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) + } + + test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") { + val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" + + " FROM h2.test.employee where dept > 0 group by DePt") + checkFiltersRemoved(df) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + } + + test("scan with aggregate push-down: CORR with filter and group by") { + val df = sql("select CORR(bonus, bonus) FROM h2.test.employee where dept > 0" + + " group by DePt") + checkFiltersRemoved(df) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(1d), 
Row(1d), Row(null))) + } + test("scan with aggregate push-down: aggregate over alias NOT push down") { val cols = Seq("a", "b", "c", "d") val df1 = sql("select * from h2.test.employee").toDF(cols: _*) From 470371c68266b7f069f3d043ca2c6570bdf1084e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 7 Jan 2022 12:59:40 +0800 Subject: [PATCH 14/53] [SPARK-37734][SQL][TESTS] Upgrade h2 from 1.4.195 to 2.0.204 ### What changes were proposed in this pull request? This PR aims to upgrade `com.h2database` from 1.4.195 to 2.0.202 ### Why are the changes needed? Fix one vulnerability, ref: https://www.tenable.com/cve/CVE-2021-23463 ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? Jenkins test. Closes #35013 from beliefer/SPARK-37734. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- sql/core/pom.xml | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 31 +++++-------------- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 2 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 21 ++++++++----- 4 files changed, 23 insertions(+), 33 deletions(-) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 85bb234cf9a97..9afd9a3ef54b5 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -136,7 +136,7 @@ com.h2database h2 - 1.4.195 + 2.0.204 test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 8842db2a2aca4..3cb91b8b00190 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -24,7 +24,6 @@ import java.util.{Calendar, GregorianCalendar, Properties, TimeZone} import scala.collection.JavaConverters._ -import org.h2.jdbc.JdbcSQLException import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} @@ -54,7 +53,8 @@ class JDBCSuite extends QueryTest val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null - val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) + val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) ++ + Array.fill(15)(0.toByte) val testH2Dialect = new JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2") @@ -87,7 +87,6 @@ class JDBCSuite extends QueryTest val properties = new Properties() properties.setProperty("user", "testUser") properties.setProperty("password", "testPass") - properties.setProperty("rowId", "false") conn = DriverManager.getConnection(url, properties) conn.prepareStatement("create schema test").executeUpdate() @@ -162,7 +161,7 @@ class JDBCSuite extends QueryTest |OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)" + conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP(7))" ).executeUpdate() conn.prepareStatement("insert into test.timetypes values ('12:34:56', " + "'1996-01-01', '2002-02-20 11:22:33.543543543')").executeUpdate() @@ -177,12 +176,12 @@ class JDBCSuite extends QueryTest """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("CREATE TABLE test.timezone (tz TIMESTAMP WITH TIME ZONE) " + - "AS SELECT '1999-01-08 04:05:06.543543543 GMT-08:00'") + "AS SELECT '1999-01-08 04:05:06.543543543-08:00'") .executeUpdate() 
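// Note: H2 2.x is stricter about TIMESTAMP WITH TIME ZONE literals; the zone-name style
// 'GMT-08:00' accepted by 1.4.x is replaced above by the numeric offset form '-08:00'.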
conn.commit() - conn.prepareStatement("CREATE TABLE test.array (ar ARRAY) " + - "AS SELECT '(1, 2, 3)'") + conn.prepareStatement("CREATE TABLE test.array_table (ar Integer ARRAY) " + + "AS SELECT ARRAY[1, 2, 3]") .executeUpdate() conn.commit() @@ -638,7 +637,7 @@ class JDBCSuite extends QueryTest assert(rows(0).getAs[Array[Byte]](0).sameElements(testBytes)) assert(rows(0).getString(1).equals("Sensitive")) assert(rows(0).getString(2).equals("Insensitive")) - assert(rows(0).getString(3).equals("Twenty-byte CHAR")) + assert(rows(0).getString(3).equals("Twenty-byte CHAR ")) assert(rows(0).getAs[Array[Byte]](4).sameElements(testBytes)) assert(rows(0).getString(5).equals("I am a clob!")) } @@ -729,20 +728,6 @@ class JDBCSuite extends QueryTest assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) } - test("Pass extra properties via OPTIONS") { - // We set rowId to false during setup, which means that _ROWID_ column should be absent from - // all tables. If rowId is true (default), the query below doesn't throw an exception. - intercept[JdbcSQLException] { - sql( - s""" - |CREATE OR REPLACE TEMPORARY VIEW abc - |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)', - | user 'testUser', password 'testPass') - """.stripMargin.replaceAll("\n", " ")) - } - } - test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties()) @@ -1375,7 +1360,7 @@ class JDBCSuite extends QueryTest }.getMessage assert(e.contains("Unsupported type TIMESTAMP_WITH_TIMEZONE")) e = intercept[SQLException] { - spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY", new Properties()).collect() + spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY_TABLE", new Properties()).collect() }.getMessage assert(e.contains("Unsupported type ARRAY")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index a84d06b33534c..1c77f7a426da4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -433,7 +433,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "PushedGroupByColumns: []" checkKeywordsExistsInExplain(df, expected_plan_fragment) } - checkAnswer(df, Seq(Row(2, 1.0))) + checkAnswer(df, Seq(Row(2, 1.5))) } test("partitioned scan with aggregate push-down: complete push-down only") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index efa2773bfd692..79952e5a6c288 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -227,7 +227,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { JdbcDialects.registerDialect(testH2Dialect) val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val m = intercept[org.h2.jdbc.JdbcSQLException] { + val m = intercept[org.h2.jdbc.JdbcSQLSyntaxErrorException] { df.write.option("createTableOptions", "ENGINE tableEngineName") .jdbc(url1, "TEST.CREATETBLOPTS", properties) }.getMessage @@ -326,7 +326,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { test("save errors if wrong user/password combination") { val df = 
spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val e = intercept[org.h2.jdbc.JdbcSQLException] { + val e = intercept[org.h2.jdbc.JdbcSQLInvalidAuthorizationSpecException] { df.write.format("jdbc") .option("dbtable", "TEST.SAVETEST") .option("url", url1) @@ -427,7 +427,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { // verify the data types of the created table by reading the database catalog of H2 val query = """ - |(SELECT column_name, type_name, character_maximum_length + |(SELECT column_name, data_type, character_maximum_length | FROM information_schema.columns WHERE table_name = 'DBCOLTYPETEST') """.stripMargin val rows = spark.read.jdbc(url1, query, properties).collect() @@ -436,7 +436,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { val typeName = row.getString(1) // For CHAR and VARCHAR, we also compare the max length if (typeName.contains("CHAR")) { - val charMaxLength = row.getInt(2) + val charMaxLength = row.getLong(2) assert(expectedTypes(row.getString(0)) == s"$typeName($charMaxLength)") } else { assert(expectedTypes(row.getString(0)) == typeName) @@ -452,15 +452,18 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { val df = spark.createDataFrame(sparkContext.parallelize(data), schema) // out-of-order - val expected1 = Map("id" -> "BIGINT", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + val expected1 = + Map("id" -> "BIGINT", "first#name" -> "CHARACTER VARYING(123)", "city" -> "CHARACTER(20)") testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), id BIGINT, city CHAR(20)", expected1) // partial schema - val expected2 = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + val expected2 = + Map("id" -> "INTEGER", "first#name" -> "CHARACTER VARYING(123)", "city" -> "CHARACTER(20)") testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), city CHAR(20)", expected2) withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { // should still respect the original column names - val expected = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CLOB") + val expected = Map("id" -> "INTEGER", "first#name" -> "CHARACTER VARYING(123)", + "city" -> "CHARACTER LARGE OBJECT(9223372036854775807)") testUserSpecifiedColTypes(df, "`FiRsT#NaMe` VARCHAR(123)", expected) } @@ -470,7 +473,9 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { StructField("First#Name", StringType) :: StructField("city", StringType) :: Nil) val df = spark.createDataFrame(sparkContext.parallelize(data), schema) - val expected = Map("id" -> "INTEGER", "First#Name" -> "VARCHAR(123)", "city" -> "CLOB") + val expected = + Map("id" -> "INTEGER", "First#Name" -> "CHARACTER VARYING(123)", + "city" -> "CHARACTER LARGE OBJECT(9223372036854775807)") testUserSpecifiedColTypes(df, "`First#Name` VARCHAR(123)", expected) } } From 6aeb2a5892d54157e37757e95b6c245f26a20fac Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 10 Jan 2022 21:47:55 +0800 Subject: [PATCH 15/53] [SPARK-37527][SQL] Compile `COVAR_POP`, `COVAR_SAMP` and `CORR` in `H2Dialet` ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/35101 translate `COVAR_POP`, `COVAR_SAMP` and `CORR`, but the H2 lower version cannot support them. After https://github.com/apache/spark/pull/35013, we can compile the three aggregate functions in `H2Dialet` now. ### Why are the changes needed? Supplement the implement of `H2Dialet`. ### Does this PR introduce _any_ user-facing change? 'Yes'. 
Spark could complete push-down `COVAR_POP`, `COVAR_SAMP` and `CORR` into H2. ### How was this patch tested? Test updated. Closes #35145 from beliefer/SPARK-37527_followup. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/jdbc/H2Dialect.scala | 12 ++++++++++++ .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 12 ++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 087c3573fbdbf..1f422e5a59cf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -47,6 +47,18 @@ private object H2Dialect extends JdbcDialect { assert(f.inputs().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") case _ => None } ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 1c77f7a426da4..72dde8fa13222 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -749,11 +749,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" + " FROM h2.test.employee where dept > 0 group by DePt") checkFiltersRemoved(df) - checkAggregateRemoved(df, false) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]" + "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) @@ -763,11 +765,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = sql("select CORR(bonus, bonus) FROM h2.test.employee where dept > 0" + " group by DePt") checkFiltersRemoved(df) - checkAggregateRemoved(df, false) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]" + "PushedAggregates: [CORR(BONUS, BONUS)], " + + "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) From b5dc371f158ec6c76473014841bcffdd2adfc29e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 20 
Jan 2022 12:13:00 +0800 Subject: [PATCH 16/53] [SPARK-37839][SQL] DS V2 supports partial aggregate push-down `AVG` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? `max`,`min`,`count`,`sum`,`avg` are the most commonly used aggregation functions. Currently, DS V2 supports complete aggregate push-down of `avg`. But, supports partial aggregate push-down of `avg` is very useful. The aggregate push-down algorithm is: 1. Spark translates group expressions of `Aggregate` to DS V2 `Aggregation`. 2. Spark calls `supportCompletePushDown` to check if it can completely push down aggregate. 3. If `supportCompletePushDown` returns true, we preserves the aggregate expressions as final aggregate expressions. Otherwise, we split `AVG` into 2 functions: `SUM` and `COUNT`. 4. Spark translates final aggregate expressions and group expressions of `Aggregate` to DS V2 `Aggregation` again, and pushes the `Aggregation` to JDBC source. 5. Spark constructs the final aggregate. ### Why are the changes needed? DS V2 supports partial aggregate push-down `AVG` ### Does this PR introduce _any_ user-facing change? 'Yes'. DS V2 could partial aggregate push-down `AVG` ### How was this patch tested? New tests. Closes #35130 from beliefer/SPARK-37839. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../connector/expressions/aggregate/Avg.java | 49 ++++++++ .../aggregate/GeneralAggregateFunc.java | 1 - .../expressions/aggregate/Average.scala | 2 +- .../datasources/DataSourceStrategy.scala | 29 ++++- .../datasources/v2/PushDownUtils.scala | 40 +------ .../v2/V2ScanRelationPushDown.scala | 108 ++++++++++++++---- .../apache/spark/sql/jdbc/JdbcDialects.scala | 11 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 82 ++++++++++++- 8 files changed, 250 insertions(+), 72 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java new file mode 100644 index 0000000000000..5e10ec9ee1644 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * An aggregate function that returns the mean of all the values in a group. 
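+ * <p>
+ * For example (illustrative), the catalyst aggregate {@code avg(salary)} can be represented as
+ * {@code new Avg(FieldReference.column("salary"), false)}, whose {@code toString} renders it as
+ * {@code AVG(salary)}.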
+ * + * @since 3.3.0 + */ +@Evolving +public final class Avg implements AggregateFunc { + private final NamedReference column; + private final boolean isDistinct; + + public Avg(NamedReference column, boolean isDistinct) { + this.column = column; + this.isDistinct = isDistinct; + } + + public NamedReference column() { return column; } + public boolean isDistinct() { return isDistinct; } + + @Override + public String toString() { + if (isDistinct) { + return "AVG(DISTINCT " + column.describe() + ")"; + } else { + return "AVG(" + column.describe() + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java index 32615e201643b..0ff26c8875b7a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -31,7 +31,6 @@ *

 * The currently supported SQL aggregate functions:
 * <ol>
- *  <li><pre>AVG(input1)</pre> Since 3.3.0</li>
 *  <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
 *  <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
 *  <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
  8. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 9714a096a69a2..05f7edaeb5d48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -69,7 +69,7 @@ case class Average( case _ => DoubleType } - private lazy val sumDataType = child.dataType match { + lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _: YearMonthIntervalType => YearMonthIntervalType() case _: DayTimeIntervalType => DayTimeIntervalType() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 990a00ca918fb..1934ef9f03228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -717,7 +717,7 @@ object DataSourceStrategy case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => Some(new Sum(FieldReference(name), aggregates.isDistinct)) case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc("AVG", aggregates.isDistinct, Array(FieldReference(name)))) + Some(new Avg(FieldReference(name), aggregates.isDistinct)) case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) => Some(new GeneralAggregateFunc("VAR_POP", aggregates.isDistinct, Array(FieldReference(name)))) case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) => @@ -746,6 +746,31 @@ object DataSourceStrategy } } + /** + * Translate aggregate expressions and group by expressions. + * + * @return translated aggregation. 
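+ *         ({@code None} is returned if any aggregate or group-by expression cannot be
+ *         translated, letting the caller fall back to not pushing the aggregation)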
+ */ + protected[sql] def translateAggregation( + aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = { + + def columnAsString(e: Expression): Option[FieldReference] = e match { + case PushableColumnWithoutNestedColumn(name) => + Some(FieldReference.column(name).asInstanceOf[FieldReference]) + case _ => None + } + + val translatedAggregates = aggregates.flatMap(translateAggregate) + val translatedGroupBys = groupBy.flatMap(columnAsString) + + if (translatedAggregates.length != aggregates.length || + translatedGroupBys.length != groupBy.length) { + return None + } + + Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)) + } + protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = { def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match { case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 2b26eee45221d..b54917e49ed3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder} -import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -103,38 +101,6 @@ object PushDownUtils extends PredicateHelper { } } - /** - * Pushes down aggregates to the data source reader - * - * @return pushed aggregation. 
- */ - def pushAggregates( - scanBuilder: ScanBuilder, - aggregates: Seq[AggregateExpression], - groupBy: Seq[Expression]): Option[Aggregation] = { - - def columnAsString(e: Expression): Option[FieldReference] = e match { - case PushableColumnWithoutNestedColumn(name) => - Some(FieldReference(name).asInstanceOf[FieldReference]) - case _ => None - } - - scanBuilder match { - case r: SupportsPushDownAggregates if aggregates.nonEmpty => - val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate) - val translatedGroupBys = groupBy.flatMap(columnAsString) - - if (translatedAggregates.length != aggregates.length || - translatedGroupBys.length != groupBy.length) { - return None - } - - val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray) - Some(agg).filter(r.pushAggregation) - case _ => None - } - } - /** * Pushes down TableSample to the data source Scan */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 67002e50e4680..05857c545cdf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,18 +19,18 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.SortOrder -import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, GeneralAggregateFunc} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType} import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { @@ -86,27 +86,68 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) => sHolder.builder match { - case _: SupportsPushDownAggregates => + case r: SupportsPushDownAggregates => val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] 
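// Maps each distinct (canonicalized) aggregate expression to its ordinal in the pushed
// aggregate output, so that an aggregate appearing twice, as in `SELECT max(a) + 1, max(a) + 2`,
// is pushed to the source only once.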
- var ordinal = 0 - val aggregates = resultExpressions.flatMap { expr => - expr.collect { - // Do not push down duplicated aggregate expressions. For example, - // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one - // `max(a)` to the data source. - case agg: AggregateExpression - if !aggExprToOutputOrdinal.contains(agg.canonicalized) => - aggExprToOutputOrdinal(agg.canonicalized) = ordinal - ordinal += 1 - agg - } - } + val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( groupingExpressions, sHolder.relation.output) - val pushedAggregates = PushDownUtils.pushAggregates( - sHolder.builder, normalizedAggregates, normalizedGroupingExpressions) + val translatedAggregates = DataSourceStrategy.translateAggregation( + normalizedAggregates, normalizedGroupingExpressions) + val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { + if (translatedAggregates.isEmpty || + r.supportCompletePushDown(translatedAggregates.get) || + translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { + (resultExpressions, aggregates, translatedAggregates) + } else { + // scalastyle:off + // The data source doesn't support the complete push-down of this aggregation. + // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be + // pushed, completely or partially. + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT avg(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // + // After convert avg(c1#9) to sum(c1#9)/count(c1#9) + // we have the following + // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // scalastyle:on + val newResultExpressions = resultExpressions.map { expr => + expr.transform { + case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => + val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) + val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) + // Closely follow `Average.evaluateExpression` + avg.dataType match { + case _: YearMonthIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count)) + case _: DayTimeIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count)) + case _ => + // TODO deal with the overflow issue + Divide(addCastIfNeeded(sum, avg.dataType), + addCastIfNeeded(count, avg.dataType), false) + } + } + }.asInstanceOf[Seq[NamedExpression]] + // Because aggregate expressions changed, translate them again. 
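+ // (The map was keyed on the original AVG-based expressions; after the SUM/COUNT rewrite those
+ // keys no longer match, so it is cleared and rebuilt while collecting the new aggregates.)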
+ aggExprToOutputOrdinal.clear() + val newAggregates = + collectAggregates(newResultExpressions, aggExprToOutputOrdinal) + val newNormalizedAggregates = DataSourceStrategy.normalizeExprs( + newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] + (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation( + newNormalizedAggregates, normalizedGroupingExpressions)) + } + } + + val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) if (pushedAggregates.isEmpty) { aggNode // return original plan node } else if (!supportPartialAggPushDown(pushedAggregates.get) && @@ -129,7 +170,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] // scalastyle:on val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + aggregates.length) + assert(newOutput.length == groupingExpressions.length + finalAggregates.length) val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) case (_, b) => b @@ -164,7 +205,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { Project(projectExpressions, scanRelation) } else { val plan = Aggregate( - output.take(groupingExpressions.length), resultExpressions, scanRelation) + output.take(groupingExpressions.length), finalResultExpressions, scanRelation) // scalastyle:off // Change the optimized logical plan to reflect the pushed down aggregate @@ -210,16 +251,33 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def collectAggregates(resultExpressions: Seq[NamedExpression], + aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { + var ordinal = 0 + resultExpressions.flatMap { expr => + expr.collect { + // Do not push down duplicated aggregate expressions. For example, + // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one + // `max(a)` to the data source. + case agg: AggregateExpression + if !aggExprToOutputOrdinal.contains(agg.canonicalized) => + aggExprToOutputOrdinal(agg.canonicalized) = ordinal + ordinal += 1 + agg + } + } + } + private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. 
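// (Partial results of Min/Max/Sum/Count can be merged again on the Spark side, e.g. as a sum of
// sums, but a GeneralAggregateFunc such as VAR_POP or CORR is opaque to Spark, so it can only be
// pushed when the source evaluates the whole aggregate, i.e. complete push-down.)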
agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc]) } - private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) = - if (aggAttribute.dataType == aggDataType) { - aggAttribute + private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) = + if (expression.dataType == expectedDataType) { + expression } else { - Cast(aggAttribute, aggDataType) + Cast(expression, expectedDataType) } def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index e516960bb6746..7456b390c616e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -217,10 +217,11 @@ abstract class JdbcDialect extends Serializable with Logging{ Some(s"SUM($distinct$column)") case _: CountStar => Some("COUNT(*)") - case f: GeneralAggregateFunc if f.name() == "AVG" => - assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"AVG($distinct${f.inputs().head})") + case avg: Avg => + if (avg.column.fieldNames.length != 1) return None + val distinct = if (avg.isDistinct) "DISTINCT " else "" + val column = quoteIdentifier(avg.column.fieldNames.head) + Some(s"AVG($distinct$column)") case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 72dde8fa13222..637e01c260c99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{lit, sum, udf} +import org.apache.spark.sql.functions.{avg, count, lit, sum, udf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -831,4 +831,84 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1))) } + + test("scan with aggregate push-down: complete push-down SUM, AVG, COUNT") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + 
.agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + .groupBy($"name") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df2, Seq( + Row("alex", 12000.00, 12000.000000, 1), + Row("amy", 10000.00, 10000.000000, 1), + Row("cathy", 9000.00, 9000.000000, 1), + Row("david", 10000.00, 10000.000000, 1), + Row("jen", 12000.00, 12000.000000, 1))) + } + + test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"name") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + checkAnswer(df2, Seq( + Row("alex", 12000.00, 12000.000000, 1), + Row("amy", 10000.00, 10000.000000, 1), + Row("cathy", 9000.00, 9000.000000, 1), + Row("david", 10000.00, 10000.000000, 1), + Row("jen", 12000.00, 12000.000000, 1))) + } } From fd06d4468fcb983d489a870b62a402e32530cdbd Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 29 Sep 2021 10:49:29 +0800 Subject: [PATCH 17/53] [SPARK-36526][SQL] DSV2 Index Support: Add supportsIndex interface ### What changes were proposed in this pull request? Indexes are database objects created on one or more columns of a table. Indexes are used to improve query performance. A detailed explanation of database index is here https://en.wikipedia.org/wiki/Database_index This PR adds `supportsIndex` interface that provides APIs to work with indexes. ### Why are the changes needed? Many data sources support index to improvement query performance. In order to take advantage of the index support in data source, this `supportsIndex` interface is added to let user to create/drop an index, list indexes, etc. ### Does this PR introduce _any_ user-facing change? 
yes, the following new APIs are added: - createIndex - dropIndex - indexExists - listIndexes New SQL syntax: ``` CREATE [index_type] INDEX [index_name] ON [TABLE] table_name (column_index_property_list)[OPTIONS indexPropertyList] column_index_property_list: column_name [OPTIONS(indexPropertyList)] [ , . . . ] indexPropertyList: index_property_name = index_property_value [ , . . . ] DROP INDEX index_name ``` ### How was this patch tested? only interface is added for now. Tests will be added when doing the implementation Closes #33754 from huaxingao/index_interface. Authored-by: Huaxin Gao Signed-off-by: Wenchen Fan --- .../catalog/index/SupportsIndex.java | 75 ++++++++++++++++++ .../connector/catalog/index/TableIndex.java | 79 +++++++++++++++++++ .../analysis/AlreadyExistException.scala | 3 + .../analysis/NoSuchItemException.scala | 3 + 4 files changed, 160 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java new file mode 100644 index 0000000000000..a8d55fb0b9c85 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.index; + +import java.util.Map; +import java.util.Properties; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.analysis.IndexAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.NoSuchIndexException; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * Table methods for working with index + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsIndex extends Table { + + /** + * Creates an index. + * + * @param indexName the name of the index to be created + * @param indexType the IndexType of the index to be created + * @param columns the columns on which index to be created + * @param columnProperties the properties of the columns on which index to be created + * @param properties the properties of the index to be created + * @throws IndexAlreadyExistsException If the index already exists (optional) + */ + void createIndex(String indexName, + String indexType, + NamedReference[] columns, + Map[] columnProperties, + Properties properties) + throws IndexAlreadyExistsException; + + /** + * Drops the index with the given name. 
+ * + * @param indexName the name of the index to be dropped. + * @return true if the index is dropped + * @throws NoSuchIndexException If the index does not exist (optional) + */ + boolean dropIndex(String indexName) throws NoSuchIndexException; + + /** + * Checks whether an index exists in this table. + * + * @param indexName the name of the index + * @return true if the index exists, false otherwise + */ + boolean indexExists(String indexName); + + /** + * Lists all the indexes in this table. + */ + TableIndex[] listIndexes(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java new file mode 100644 index 0000000000000..99fce806a11b9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.index; + +import java.util.Collections; +import java.util.Map; +import java.util.Properties; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * Index in a table + * + * @since 3.3.0 + */ +@Evolving +public final class TableIndex { + private String indexName; + private String indexType; + private NamedReference[] columns; + private Map columnProperties = Collections.emptyMap(); + private Properties properties; + + public TableIndex( + String indexName, + String indexType, + NamedReference[] columns, + Map columnProperties, + Properties properties) { + this.indexName = indexName; + this.indexType = indexType; + this.columns = columns; + this.columnProperties = columnProperties; + this.properties = properties; + } + + /** + * @return the Index name. + */ + String indexName() { return indexName; } + + /** + * @return the indexType of this Index. + */ + String indexType() { return indexType; } + + /** + * @return the column(s) this Index is on. Could be multi columns (a multi-column index). + */ + NamedReference[] columns() { return columns; } + + /** + * @return the map of column and column property map. + */ + Map columnProperties() { return columnProperties; } + + /** + * Returns the index properties. 
+ */ + Properties properties() { + return properties; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 70f821d5f8af0..ce48cfa89a389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -78,3 +78,6 @@ class PartitionsAlreadyExistException(message: String) extends AnalysisException class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") + +class IndexAlreadyExistsException(indexName: String, table: Identifier) + extends AnalysisException(s"Index '$indexName' already exists in table ${table.quoted}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index ba5a9c618c650..7a9f7b5c6bced 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -95,3 +95,6 @@ class NoSuchPartitionsException(message: String) extends AnalysisException(messa class NoSuchTempFunctionException(func: String) extends AnalysisException(s"Temporary function '$func' not found") + +class NoSuchIndexException(indexName: String) + extends AnalysisException(s"Index '$indexName' not found") From 52b36b0f4ed0e07129c6ab56d487dc3a0e73e755 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 8 Oct 2021 11:30:12 -0700 Subject: [PATCH 18/53] [SPARK-36913][SQL] Implement createIndex and IndexExists in DS V2 JDBC (MySQL dialect) ### What changes were proposed in this pull request? Implementing `createIndex`/`IndexExists` in DS V2 JDBC ### Why are the changes needed? This is a subtask of the V2 Index support. I am implementing index support for DS V2 JDBC so we can have a POC and an end to end testing. This PR implements `createIndex` and `IndexExists`. Next PR will implement `listIndexes` and `dropIndex`. I intentionally make the PR small so it's easier to review. Index is not supported by h2 database and create/drop index are not standard SQL syntax. This PR only implements `createIndex` and `IndexExists` in `MySQL` dialect. ### Does this PR introduce _any_ user-facing change? Yes, `createIndex`/`IndexExist` in DS V2 JDBC ### How was this patch tested? new test Closes #34164 from huaxingao/createIndexJDBC. 
Authored-by: Huaxin Gao Signed-off-by: Liang-Chi Hsieh --- .../sql/jdbc/v2/MySQLIntegrationSuite.scala | 33 +++++++++ .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 10 +++ .../catalog/index/SupportsIndex.java | 4 +- .../analysis/AlreadyExistException.scala | 4 +- .../datasources/jdbc/JdbcUtils.scala | 58 +++++++++++++++ .../datasources/v2/jdbc/JDBCTable.scala | 36 ++++++++- .../v2/jdbc/JDBCTableCatalog.scala | 55 +++++--------- .../apache/spark/sql/jdbc/JdbcDialects.scala | 41 +++++++++- .../apache/spark/sql/jdbc/MySQLDialect.scala | 74 ++++++++++++++++++- 9 files changed, 269 insertions(+), 46 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index db626dfdf8c39..3cb878774f2e9 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -18,11 +18,16 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.{Connection, SQLFeatureNotSupportedException} +import java.util import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.IndexAlreadyExistsException +import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} +import org.apache.spark.sql.connector.catalog.index.SupportsIndex +import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} import org.apache.spark.sql.types._ @@ -115,4 +120,32 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) assert(t.schema === expectedSchema) } + + override def testIndex(tbl: String): Unit = { + val loaded = Catalogs.load("mysql", conf) + val jdbcTable = loaded.asInstanceOf[TableCatalog] + .loadTable(Identifier.of(Array.empty[String], "new_table")) + .asInstanceOf[SupportsIndex] + assert(jdbcTable.indexExists("i1") == false) + assert(jdbcTable.indexExists("i2") == false) + + val properties = new util.Properties(); + properties.put("KEY_BLOCK_SIZE", "10") + properties.put("COMMENT", "'this is a comment'") + jdbcTable.createIndex("i1", "", Array(FieldReference("col1")), + Array.empty[util.Map[NamedReference, util.Properties]], properties) + + jdbcTable.createIndex("i2", "", + Array(FieldReference("col2"), FieldReference("col3"), FieldReference("col5")), + Array.empty[util.Map[NamedReference, util.Properties]], new util.Properties) + + assert(jdbcTable.indexExists("i1") == true) + assert(jdbcTable.indexExists("i2") == true) + + val m = intercept[IndexAlreadyExistsException] { + jdbcTable.createIndex("i1", "", Array(FieldReference("col1")), + Array.empty[util.Map[NamedReference, util.Properties]], properties) + }.getMessage + assert(m.contains("Failed to create index: i1 in new_table")) + } } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 406cc41521e1c..f176726fd0af0 100644 --- 
a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -219,6 +219,16 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu scan.schema.names.sameElements(Seq(col)) } + + def testIndex(tbl: String): Unit = {} + + test("SPARK-36913: Test INDEX") { + withTable(s"$catalogName.new_table") { + sql(s"CREATE TABLE $catalogName.new_table(col1 INT, col2 INT, col3 INT, col4 INT, col5 INT)") + testIndex(s"$catalogName.new_table") + } + } + test("SPARK-37038: Test TABLESAMPLE") { if (supportsTableSample) { withTable(s"$catalogName.new_table") { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java index a8d55fb0b9c85..24961e460cc26 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java @@ -40,14 +40,14 @@ public interface SupportsIndex extends Table { * @param indexName the name of the index to be created * @param indexType the IndexType of the index to be created * @param columns the columns on which index to be created - * @param columnProperties the properties of the columns on which index to be created + * @param columnsProperties the properties of the columns on which index to be created * @param properties the properties of the index to be created * @throws IndexAlreadyExistsException If the index already exists (optional) */ void createIndex(String indexName, String indexType, NamedReference[] columns, - Map[] columnProperties, + Map[] columnsProperties, Properties properties) throws IndexAlreadyExistsException; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index ce48cfa89a389..fb177251a7306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -79,5 +79,5 @@ class PartitionsAlreadyExistException(message: String) extends AnalysisException class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") -class IndexAlreadyExistsException(indexName: String, table: Identifier) - extends AnalysisException(s"Index '$indexName' already exists in table ${table.quoted}") +class IndexAlreadyExistsException(message: String, cause: Option[Throwable] = None) + extends AnalysisException(message, cause = cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 60fcaf94e1986..31c11568e35d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, Driver, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} import java.time.{Instant, LocalDate} +import java.util 
import java.util.Locale import java.util.concurrent.TimeUnit @@ -37,6 +38,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.internal.SQLConf @@ -1019,6 +1021,35 @@ object JdbcUtils extends Logging { executeStatement(conn, options, s"DROP SCHEMA ${dialect.quoteIdentifier(namespace)}") } + /** + * Create an index. + */ + def createIndex( + conn: Connection, + indexName: String, + indexType: String, + tableName: String, + columns: Array[NamedReference], + columnsProperties: Array[util.Map[NamedReference, util.Properties]], + properties: util.Properties, + options: JDBCOptions): Unit = { + val dialect = JdbcDialects.get(options.url) + executeStatement(conn, options, + dialect.createIndex(indexName, indexType, tableName, columns, columnsProperties, properties)) + } + + /** + * Check if an index exists + */ + def indexExists( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Boolean = { + val dialect = JdbcDialects.get(options.url) + dialect.indexExists(conn, indexName, tableName, options) + } + private def executeStatement(conn: Connection, options: JDBCOptions, sql: String): Unit = { val statement = conn.createStatement try { @@ -1028,4 +1059,31 @@ object JdbcUtils extends Logging { statement.close() } } + + def executeQuery(conn: Connection, options: JDBCOptions, sql: String): ResultSet = { + val statement = conn.createStatement + try { + statement.setQueryTimeout(options.queryTimeout) + statement.executeQuery(sql) + } finally { + statement.close() + } + } + + def classifyException[T](message: String, dialect: JdbcDialect)(f: => T): T = { + try { + f + } catch { + case e: Throwable => throw dialect.classifyException(message, e) + } + } + + def withConnection[T](options: JDBCOptions)(f: Connection => T): T = { + val conn = createConnectionFactory(options)() + try { + f(conn) + } finally { + conn.close() + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala index 5e11ea66be4c6..957d021963a7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala @@ -23,13 +23,16 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex} +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JdbcUtils} +import org.apache.spark.sql.jdbc.JdbcDialects import 
org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOptions) - extends Table with SupportsRead with SupportsWrite { + extends Table with SupportsRead with SupportsWrite with SupportsIndex { override def name(): String = ident.toString @@ -48,4 +51,33 @@ case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOpt jdbcOptions.parameters.originalMap ++ info.options.asCaseSensitiveMap().asScala) JDBCWriteBuilder(schema, mergedOptions) } + + override def createIndex( + indexName: String, + indexType: String, + columns: Array[NamedReference], + columnsProperties: Array[util.Map[NamedReference, util.Properties]], + properties: util.Properties): Unit = { + JdbcUtils.withConnection(jdbcOptions) { conn => + JdbcUtils.classifyException(s"Failed to create index: $indexName in $name", + JdbcDialects.get(jdbcOptions.url)) { + JdbcUtils.createIndex( + conn, indexName, indexType, name, columns, columnsProperties, properties, jdbcOptions) + } + } + } + + override def indexExists(indexName: String): Boolean = { + JdbcUtils.withConnection(jdbcOptions) { conn => + JdbcUtils.indexExists(conn, indexName, name, jdbcOptions) + } + } + + override def dropIndex(indexName: String): Boolean = { + throw new UnsupportedOperationException("dropIndex is not supported yet") + } + + override def listIndexes(): Array[TableIndex] = { + throw new UnsupportedOperationException("listIndexes is not supported yet") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala index a90ab564ddb50..566706486d3f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.execution.datasources.v2.jdbc -import java.sql.{Connection, SQLException} +import java.sql.SQLException import java.util import scala.collection.JavaConverters._ @@ -57,7 +57,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def listTables(namespace: Array[String]): Array[Identifier] = { checkNamespace(namespace) - withConnection { conn => + JdbcUtils.withConnection(options) { conn => val schemaPattern = if (namespace.length == 1) namespace.head else null val rs = conn.getMetaData .getTables(null, schemaPattern, "%", Array("TABLE")); @@ -72,14 +72,14 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging checkNamespace(ident.namespace()) val writeOptions = new JdbcOptionsInWrite( options.parameters + (JDBCOptions.JDBC_TABLE_NAME -> getTableName(ident))) - classifyException(s"Failed table existence check: $ident") { - withConnection(JdbcUtils.tableExists(_, writeOptions)) + JdbcUtils.classifyException(s"Failed table existence check: $ident", dialect) { + JdbcUtils.withConnection(options)(JdbcUtils.tableExists(_, writeOptions)) } } override def dropTable(ident: Identifier): Boolean = { checkNamespace(ident.namespace()) - withConnection { conn => + JdbcUtils.withConnection(options) { conn => try { JdbcUtils.dropTable(conn, getTableName(ident), options) true @@ -91,8 +91,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def renameTable(oldIdent: Identifier, 
newIdent: Identifier): Unit = { checkNamespace(oldIdent.namespace()) - withConnection { conn => - classifyException(s"Failed table renaming from $oldIdent to $newIdent") { + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed table renaming from $oldIdent to $newIdent", dialect) { JdbcUtils.renameTable(conn, getTableName(oldIdent), getTableName(newIdent), options) } } @@ -151,8 +151,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging val writeOptions = new JdbcOptionsInWrite(tableOptions) val caseSensitive = SQLConf.get.caseSensitiveAnalysis - withConnection { conn => - classifyException(s"Failed table creation: $ident") { + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed table creation: $ident", dialect) { JdbcUtils.createTable(conn, getTableName(ident), schema, caseSensitive, writeOptions) } } @@ -162,8 +162,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def alterTable(ident: Identifier, changes: TableChange*): Table = { checkNamespace(ident.namespace()) - withConnection { conn => - classifyException(s"Failed table altering: $ident") { + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed table altering: $ident", dialect) { JdbcUtils.alterTable(conn, getTableName(ident), changes, options) } loadTable(ident) @@ -172,7 +172,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def namespaceExists(namespace: Array[String]): Boolean = namespace match { case Array(db) => - withConnection { conn => + JdbcUtils.withConnection(options) { conn => val rs = conn.getMetaData.getSchemas(null, db) while (rs.next()) { if (rs.getString(1) == db) return true; @@ -183,7 +183,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } override def listNamespaces(): Array[Array[String]] = { - withConnection { conn => + JdbcUtils.withConnection(options) { conn => val schemaBuilder = ArrayBuilder.make[Array[String]] val rs = conn.getMetaData.getSchemas() while (rs.next()) { @@ -234,8 +234,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } } } - withConnection { conn => - classifyException(s"Failed create name space: $db") { + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed create name space: $db", dialect) { JdbcUtils.createNamespace(conn, options, db, comment) } } @@ -253,7 +253,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging changes.foreach { case set: NamespaceChange.SetProperty => if (set.property() == SupportsNamespaces.PROP_COMMENT) { - withConnection { conn => + JdbcUtils.withConnection(options) { conn => JdbcUtils.createNamespaceComment(conn, options, db, set.value) } } else { @@ -262,7 +262,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging case unset: NamespaceChange.RemoveProperty => if (unset.property() == SupportsNamespaces.PROP_COMMENT) { - withConnection { conn => + JdbcUtils.withConnection(options) { conn => JdbcUtils.removeNamespaceComment(conn, options, db) } } else { @@ -283,8 +283,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging if (listTables(Array(db)).nonEmpty) { throw QueryExecutionErrors.namespaceNotEmptyError(namespace) } - withConnection { conn => - classifyException(s"Failed drop name space: $db") { + JdbcUtils.withConnection(options) { conn => + 
JdbcUtils.classifyException(s"Failed drop name space: $db", dialect) { JdbcUtils.dropNamespace(conn, options, db) true } @@ -301,24 +301,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } } - private def withConnection[T](f: Connection => T): T = { - val conn = JdbcUtils.createConnectionFactory(options)() - try { - f(conn) - } finally { - conn.close() - } - } - private def getTableName(ident: Identifier): String = { (ident.namespace() :+ ident.name()).map(dialect.quoteIdentifier).mkString(".") } - - private def classifyException[T](message: String)(f: => T): T = { - try { - f - } catch { - case e: Throwable => throw dialect.classifyException(message, e) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 7456b390c616e..dcf9d0f0cfa52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Date, Timestamp} import java.time.{Instant, LocalDate} +import java.util import scala.collection.mutable.ArrayBuilder @@ -30,9 +31,10 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -324,6 +326,43 @@ abstract class JdbcDialect extends Serializable with Logging{ s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS NULL" } + /** + * Creates an index. + * + * @param indexName the name of the index to be created + * @param indexType the type of the index to be created + * @param tableName the table on which index to be created + * @param columns the columns on which index to be created + * @param columnsProperties the properties of the columns on which index to be created + * @param properties the properties of the index to be created + */ + def createIndex( + indexName: String, + indexType: String, + tableName: String, + columns: Array[NamedReference], + columnsProperties: Array[util.Map[NamedReference, util.Properties]], + properties: util.Properties): String = { + throw new UnsupportedOperationException("createIndex is not supported") + } + + /** + * Checks whether an index exists + * + * @param indexName the name of the index + * @param tableName the table name on which index to be checked + * @param options JDBCOptions of the table + * @return true if the index with `indexName` exists in the table with `tableName`, + * false otherwise + */ + def indexExists( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Boolean = { + throw new UnsupportedOperationException("indexExists is not supported") + } + /** * Gets a dialect exception, classifies it and wraps it by `AnalysisException`. 
* @param message The error message to be placed to the returned exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index ed107707c9d1f..5c16ef6a947ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -17,14 +17,21 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{Connection, SQLException, Types} +import java.util import java.util.Locale +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.IndexAlreadyExistsException +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder} -private case object MySQLDialect extends JdbcDialect { +private case object MySQLDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url : String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") @@ -102,4 +109,65 @@ private case object MySQLDialect extends JdbcDialect { case FloatType => Option(JdbcType("FLOAT", java.sql.Types.FLOAT)) case _ => JdbcUtils.getCommonJDBCType(dt) } + + // CREATE INDEX syntax + // https://dev.mysql.com/doc/refman/8.0/en/create-index.html + override def createIndex( + indexName: String, + indexType: String, + tableName: String, + columns: Array[NamedReference], + columnsProperties: Array[util.Map[NamedReference, util.Properties]], + properties: util.Properties): String = { + val columnList = columns.map(col => quoteIdentifier(col.fieldNames.head)) + var indexProperties: String = "" + val scalaProps = properties.asScala + if (!properties.isEmpty) { + scalaProps.foreach { case (k, v) => + indexProperties = indexProperties + " " + s"$k $v" + } + } + + // columnsProperties doesn't apply to MySQL so it is ignored + s"CREATE $indexType INDEX ${quoteIdentifier(indexName)} ON" + + s" ${quoteIdentifier(tableName)}" + s" (${columnList.mkString(", ")}) $indexProperties" + } + + // SHOW INDEX syntax + // https://dev.mysql.com/doc/refman/8.0/en/show-index.html + override def indexExists( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Boolean = { + val sql = s"SHOW INDEXES FROM ${quoteIdentifier(tableName)}" + try { + val rs = JdbcUtils.executeQuery(conn, options, sql) + while (rs.next()) { + val retrievedIndexName = rs.getString("key_name") + if (conf.resolver(retrievedIndexName, indexName)) { + return true + } + } + false + } catch { + case _: Exception => + logWarning("Cannot retrieved index info.") + false + } + } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + if (e.isInstanceOf[SQLException]) { + // Error codes are from + // https://mariadb.com/kb/en/mariadb-error-codes/#shared-mariadbmysql-error-codes + e.asInstanceOf[SQLException].getErrorCode match { + // ER_DUP_KEYNAME + case 1061 => + throw new IndexAlreadyExistsException(message, cause = Some(e)) + case _ => + } + } + super.classifyException(message, e) + } } From ce631103dcdc4fc2bd6e7815423638c6ed974657 Mon Sep 17 00:00:00 2001 From: Huaxin Gao 
Date: Tue, 12 Oct 2021 22:36:47 +0800 Subject: [PATCH 19/53] [SPARK-36914][SQL] Implement dropIndex and listIndexes in JDBC (MySQL dialect) ### What changes were proposed in this pull request? This PR implements `dropIndex` and `listIndexes` in MySQL dialect ### Why are the changes needed? As a subtask of the V2 Index support, this PR completes the implementation for JDBC V2 index support. ### Does this PR introduce _any_ user-facing change? Yes, `dropIndex/listIndexes` in DS V2 JDBC ### How was this patch tested? new tests Closes #34236 from huaxingao/listIndexJDBC. Authored-by: Huaxin Gao Signed-off-by: Wenchen Fan --- .../sql/jdbc/v2/MySQLIntegrationSuite.scala | 33 ++---- .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 110 +++++++++++++++--- .../catalog/index/SupportsIndex.java | 7 +- .../connector/catalog/index/TableIndex.java | 12 +- .../analysis/NoSuchItemException.scala | 4 +- .../datasources/jdbc/JdbcUtils.scala | 24 ++++ .../datasources/v2/jdbc/JDBCTable.scala | 13 ++- .../apache/spark/sql/jdbc/JdbcDialects.scala | 25 +++- .../apache/spark/sql/jdbc/MySQLDialect.scala | 84 ++++++++++--- 9 files changed, 245 insertions(+), 67 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 3cb878774f2e9..67e81087f8fb0 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -24,8 +24,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.IndexAlreadyExistsException -import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -121,31 +119,22 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { assert(t.schema === expectedSchema) } - override def testIndex(tbl: String): Unit = { - val loaded = Catalogs.load("mysql", conf) - val jdbcTable = loaded.asInstanceOf[TableCatalog] - .loadTable(Identifier.of(Array.empty[String], "new_table")) - .asInstanceOf[SupportsIndex] - assert(jdbcTable.indexExists("i1") == false) - assert(jdbcTable.indexExists("i2") == false) + override def supportsIndex: Boolean = true + override def testIndexProperties(jdbcTable: SupportsIndex): Unit = { val properties = new util.Properties(); properties.put("KEY_BLOCK_SIZE", "10") properties.put("COMMENT", "'this is a comment'") - jdbcTable.createIndex("i1", "", Array(FieldReference("col1")), + // MySQL doesn't allow property set on individual column, so use empty Array for + // column properties + jdbcTable.createIndex("i1", "BTREE", Array(FieldReference("col1")), Array.empty[util.Map[NamedReference, util.Properties]], properties) - jdbcTable.createIndex("i2", "", - Array(FieldReference("col2"), FieldReference("col3"), FieldReference("col5")), - Array.empty[util.Map[NamedReference, util.Properties]], new util.Properties) - - assert(jdbcTable.indexExists("i1") == true) - assert(jdbcTable.indexExists("i2") == true) - - val m = intercept[IndexAlreadyExistsException] { - 
jdbcTable.createIndex("i1", "", Array(FieldReference("col1")), - Array.empty[util.Map[NamedReference, util.Properties]], properties) - }.getMessage - assert(m.contains("Failed to create index: i1 in new_table")) + var index = jdbcTable.listIndexes() + // The index property size is actually 1. Even though the index is created + // with properties "KEY_BLOCK_SIZE", "10" and "COMMENT", "'this is a comment'", when + // retrieving index using `SHOW INDEXES`, MySQL only returns `COMMENT`. + assert(index(0).properties.size == 1) + assert(index(0).properties.get("COMMENT").equals("this is a comment")) } } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index f176726fd0af0..c7c18dab6d660 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.jdbc.v2 +import java.util + import org.apache.log4j.Level -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sample} +import org.apache.spark.sql.{AnalysisException, DataFrame} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample} +import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} +import org.apache.spark.sql.connector.catalog.index.SupportsIndex +import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession @@ -186,6 +190,96 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } + def supportsIndex: Boolean = false + def testIndexProperties(jdbcTable: SupportsIndex): Unit = {} + + test("SPARK-36913: Test INDEX") { + if (supportsIndex) { + withTable(s"$catalogName.new_table") { + sql(s"CREATE TABLE $catalogName.new_table(col1 INT, col2 INT, col3 INT," + + s" col4 INT, col5 INT)") + val loaded = Catalogs.load(catalogName, conf) + val jdbcTable = loaded.asInstanceOf[TableCatalog] + .loadTable(Identifier.of(Array.empty[String], "new_table")) + .asInstanceOf[SupportsIndex] + assert(jdbcTable.indexExists("i1") == false) + assert(jdbcTable.indexExists("i2") == false) + + val properties = new util.Properties(); + val indexType = "DUMMY" + var m = intercept[UnsupportedOperationException] { + jdbcTable.createIndex("i1", indexType, Array(FieldReference("col1")), + Array.empty[util.Map[NamedReference, util.Properties]], properties) + }.getMessage + assert(m.contains(s"Index Type $indexType is not supported." 
+ + s" The supported Index Types are: BTREE and HASH")) + + jdbcTable.createIndex("i1", "BTREE", Array(FieldReference("col1")), + Array.empty[util.Map[NamedReference, util.Properties]], properties) + + jdbcTable.createIndex("i2", "", + Array(FieldReference("col2"), FieldReference("col3"), FieldReference("col5")), + Array.empty[util.Map[NamedReference, util.Properties]], properties) + + assert(jdbcTable.indexExists("i1") == true) + assert(jdbcTable.indexExists("i2") == true) + + m = intercept[IndexAlreadyExistsException] { + jdbcTable.createIndex("i1", "", Array(FieldReference("col1")), + Array.empty[util.Map[NamedReference, util.Properties]], properties) + }.getMessage + assert(m.contains("Failed to create index: i1 in new_table")) + + var index = jdbcTable.listIndexes() + assert(index.length == 2) + + assert(index(0).indexName.equals("i1")) + assert(index(0).indexType.equals("BTREE")) + var cols = index(0).columns + assert(cols.length == 1) + assert(cols(0).describe().equals("col1")) + assert(index(0).properties.size == 0) + + assert(index(1).indexName.equals("i2")) + assert(index(1).indexType.equals("BTREE")) + cols = index(1).columns + assert(cols.length == 3) + assert(cols(0).describe().equals("col2")) + assert(cols(1).describe().equals("col3")) + assert(cols(2).describe().equals("col5")) + assert(index(1).properties.size == 0) + + jdbcTable.dropIndex("i1") + assert(jdbcTable.indexExists("i1") == false) + assert(jdbcTable.indexExists("i2") == true) + + index = jdbcTable.listIndexes() + assert(index.length == 1) + + assert(index(0).indexName.equals("i2")) + assert(index(0).indexType.equals("BTREE")) + cols = index(0).columns + assert(cols.length == 3) + assert(cols(0).describe().equals("col2")) + assert(cols(1).describe().equals("col3")) + assert(cols(2).describe().equals("col5")) + + jdbcTable.dropIndex("i2") + assert(jdbcTable.indexExists("i1") == false) + assert(jdbcTable.indexExists("i2") == false) + index = jdbcTable.listIndexes() + assert(index.length == 0) + + m = intercept[NoSuchIndexException] { + jdbcTable.dropIndex("i2") + }.getMessage + assert(m.contains("Failed to drop index: i2")) + + testIndexProperties(jdbcTable) + } + } + } + def supportsTableSample: Boolean = false private def samplePushed(df: DataFrame): Boolean = { @@ -219,16 +313,6 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu scan.schema.names.sameElements(Seq(col)) } - - def testIndex(tbl: String): Unit = {} - - test("SPARK-36913: Test INDEX") { - withTable(s"$catalogName.new_table") { - sql(s"CREATE TABLE $catalogName.new_table(col1 INT, col2 INT, col3 INT, col4 INT, col5 INT)") - testIndex(s"$catalogName.new_table") - } - } - test("SPARK-37038: Test TABLESAMPLE") { if (supportsTableSample) { withTable(s"$catalogName.new_table") { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java index 24961e460cc26..4181cf5f25118 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java @@ -42,7 +42,7 @@ public interface SupportsIndex extends Table { * @param columns the columns on which index to be created * @param columnsProperties the properties of the columns on which index to be created * @param properties the properties of the index to be created - * @throws IndexAlreadyExistsException If the index 
already exists (optional) + * @throws IndexAlreadyExistsException If the index already exists. */ void createIndex(String indexName, String indexType, @@ -55,10 +55,9 @@ void createIndex(String indexName, * Drops the index with the given name. * * @param indexName the name of the index to be dropped. - * @return true if the index is dropped - * @throws NoSuchIndexException If the index does not exist (optional) + * @throws NoSuchIndexException If the index does not exist. */ - boolean dropIndex(String indexName) throws NoSuchIndexException; + void dropIndex(String indexName) throws NoSuchIndexException; /** * Checks whether an index exists in this table. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java index 99fce806a11b9..977ed8d6c7528 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java @@ -53,27 +53,25 @@ public TableIndex( /** * @return the Index name. */ - String indexName() { return indexName; } + public String indexName() { return indexName; } /** * @return the indexType of this Index. */ - String indexType() { return indexType; } + public String indexType() { return indexType; } /** * @return the column(s) this Index is on. Could be multi columns (a multi-column index). */ - NamedReference[] columns() { return columns; } + public NamedReference[] columns() { return columns; } /** * @return the map of column and column property map. */ - Map columnProperties() { return columnProperties; } + public Map columnProperties() { return columnProperties; } /** * Returns the index properties. 
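   * Note that a dialect may only be able to recover a subset of the creation-time properties; the MySQL dialect, for instance, only surfaces the COMMENT property when it reconstructs indexes from SHOW INDEXES.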
*/ - Properties properties() { - return properties; - } + public Properties properties() { return properties; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index 7a9f7b5c6bced..8b0710b2c1f19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -96,5 +96,5 @@ class NoSuchPartitionsException(message: String) extends AnalysisException(messa class NoSuchTempFunctionException(func: String) extends AnalysisException(s"Temporary function '$func' not found") -class NoSuchIndexException(indexName: String) - extends AnalysisException(s"Index '$indexName' not found") +class NoSuchIndexException(message: String, cause: Option[Throwable] = None) + extends AnalysisException(message, cause = cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 31c11568e35d5..2e21571939cf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider @@ -1050,6 +1051,29 @@ object JdbcUtils extends Logging { dialect.indexExists(conn, indexName, tableName, options) } + /** + * Drop an index. + */ + def dropIndex( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Unit = { + val dialect = JdbcDialects.get(options.url) + executeStatement(conn, options, dialect.dropIndex(indexName, tableName)) + } + + /** + * List all the indexes in a table. 
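+   * The dialect resolved from `options.url` supplies the implementation; dialects that have not implemented index support throw `UnsupportedOperationException`.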
+ */ + def listIndexes( + conn: Connection, + tableName: String, + options: JDBCOptions): Array[TableIndex] = { + val dialect = JdbcDialects.get(options.url) + dialect.listIndexes(conn, tableName, options) + } + private def executeStatement(conn: Connection, options: JDBCOptions, sql: String): Unit = { val statement = conn.createStatement try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala index 957d021963a7f..ba56643f4d980 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala @@ -73,11 +73,18 @@ case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOpt } } - override def dropIndex(indexName: String): Boolean = { - throw new UnsupportedOperationException("dropIndex is not supported yet") + override def dropIndex(indexName: String): Unit = { + JdbcUtils.withConnection(jdbcOptions) { conn => + JdbcUtils.classifyException(s"Failed to drop index: $indexName", + JdbcDialects.get(jdbcOptions.url)) { + JdbcUtils.dropIndex(conn, indexName, name, jdbcOptions) + } + } } override def listIndexes(): Array[TableIndex] = { - throw new UnsupportedOperationException("listIndexes is not supported yet") + JdbcUtils.withConnection(jdbcOptions) { conn => + JdbcUtils.listIndexes(conn, name, jdbcOptions) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index dcf9d0f0cfa52..dbf5e4c037d31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ +import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -327,7 +328,7 @@ abstract class JdbcDialect extends Serializable with Logging{ } /** - * Creates an index. + * Build a create index SQL statement. * * @param indexName the name of the index to be created * @param indexType the type of the index to be created @@ -335,6 +336,7 @@ abstract class JdbcDialect extends Serializable with Logging{ * @param columns the columns on which index to be created * @param columnsProperties the properties of the columns on which index to be created * @param properties the properties of the index to be created + * @return the SQL statement to use for creating the index. */ def createIndex( indexName: String, @@ -363,6 +365,27 @@ abstract class JdbcDialect extends Serializable with Logging{ throw new UnsupportedOperationException("indexExists is not supported") } + /** + * Build a drop index SQL statement. + * + * @param indexName the name of the index to be dropped. + * @param tableName the table name on which index to be dropped. + * @return the SQL statement to use for dropping the index. 
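+   * For example, the MySQL dialect below generates `DROP INDEX <index name> ON <table name>`.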
+ */ + def dropIndex(indexName: String, tableName: String): String = { + throw new UnsupportedOperationException("dropIndex is not supported") + } + + /** + * Lists all the indexes in this table. + */ + def listIndexes( + conn: Connection, + tableName: String, + options: JDBCOptions): Array[TableIndex] = { + throw new UnsupportedOperationException("listIndexes is not supported") + } + /** * Gets a dialect exception, classifies it and wraps it by `AnalysisException`. * @param message The error message to be placed to the returned exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index 5c16ef6a947ba..7e85b3bbb84e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -25,8 +25,9 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.analysis.IndexAlreadyExistsException -import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} +import org.apache.spark.sql.connector.catalog.index.TableIndex +import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder} @@ -127,10 +128,19 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { indexProperties = indexProperties + " " + s"$k $v" } } - + val iType = if (indexType.isEmpty) { + "" + } else { + if (indexType.length > 1 && !indexType.equalsIgnoreCase("BTREE") && + !indexType.equalsIgnoreCase("HASH")) { + throw new UnsupportedOperationException(s"Index Type $indexType is not supported." 
+ + " The supported Index Types are: BTREE and HASH") + } + s"USING $indexType" + } // columnsProperties doesn't apply to MySQL so it is ignored - s"CREATE $indexType INDEX ${quoteIdentifier(indexName)} ON" + - s" ${quoteIdentifier(tableName)}" + s" (${columnList.mkString(", ")}) $indexProperties" + s"CREATE INDEX ${quoteIdentifier(indexName)} $iType ON" + + s" ${quoteIdentifier(tableName)} (${columnList.mkString(", ")}) $indexProperties" } // SHOW INDEX syntax @@ -157,17 +167,61 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { } } - override def classifyException(message: String, e: Throwable): AnalysisException = { - if (e.isInstanceOf[SQLException]) { - // Error codes are from - // https://mariadb.com/kb/en/mariadb-error-codes/#shared-mariadbmysql-error-codes - e.asInstanceOf[SQLException].getErrorCode match { - // ER_DUP_KEYNAME - case 1061 => - throw new IndexAlreadyExistsException(message, cause = Some(e)) - case _ => + override def dropIndex(indexName: String, tableName: String): String = { + s"DROP INDEX ${quoteIdentifier(indexName)} ON $tableName" + } + + // SHOW INDEX syntax + // https://dev.mysql.com/doc/refman/8.0/en/show-index.html + override def listIndexes( + conn: Connection, + tableName: String, + options: JDBCOptions): Array[TableIndex] = { + val sql = s"SHOW INDEXES FROM $tableName" + var indexMap: Map[String, TableIndex] = Map() + try { + val rs = JdbcUtils.executeQuery(conn, options, sql) + while (rs.next()) { + val indexName = rs.getString("key_name") + val colName = rs.getString("column_name") + val indexType = rs.getString("index_type") + val indexComment = rs.getString("Index_comment") + if (indexMap.contains(indexName)) { + val index = indexMap.get(indexName).get + val newIndex = new TableIndex(indexName, indexType, + index.columns() :+ FieldReference(colName), + index.columnProperties, index.properties) + indexMap += (indexName -> newIndex) + } else { + // The only property we are building here is `COMMENT` because it's the only one + // we can get from `SHOW INDEXES`. + val properties = new util.Properties(); + if (indexComment.nonEmpty) properties.put("COMMENT", indexComment) + val index = new TableIndex(indexName, indexType, Array(FieldReference(colName)), + new util.HashMap[NamedReference, util.Properties](), properties) + indexMap += (indexName -> index) + } } + } catch { + case _: Exception => + logWarning("Cannot retrieved index info.") + } + indexMap.values.toArray + } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getErrorCode match { + // ER_DUP_KEYNAME + case 1061 => + throw new IndexAlreadyExistsException(message, cause = Some(e)) + case 1091 => + throw new NoSuchIndexException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case unsupported: UnsupportedOperationException => throw unsupported + case _ => super.classifyException(message, e) } - super.classifyException(message, e) } } From ac8fd9c875f4e0cb010b2bf3cdbdd107afd3a8c5 Mon Sep 17 00:00:00 2001 From: dch nguyen Date: Fri, 17 Dec 2021 20:29:57 +0800 Subject: [PATCH 20/53] [SPARK-37343][SQL] Implement createIndex, IndexExists and dropIndex in JDBC (Postgres dialect) ### What changes were proposed in this pull request? Implementing `createIndex`/`IndexExists`/`dropIndex` in DS V2 JDBC for Postgres dialect. ### Why are the changes needed? This is a subtask of the V2 Index support. 
This PR implements `createIndex`, `IndexExists` and `dropIndex`. After review for some changes in this PR, I will create new PR for `listIndexs`, or add it in this PR. This PR only implements `createIndex`, `IndexExists` and `dropIndex` in Postgres dialect. ### Does this PR introduce _any_ user-facing change? Yes, `createIndex`/`IndexExists`/`dropIndex` in DS V2 JDBC ### How was this patch tested? New test. Closes #34673 from dchvn/Dsv2_index_postgres. Authored-by: dch nguyen Signed-off-by: Wenchen Fan --- .../sql/jdbc/v2/MySQLIntegrationSuite.scala | 26 +----- .../jdbc/v2/PostgresIntegrationSuite.scala | 4 + .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 88 ++++++------------- .../datasources/jdbc/JdbcUtils.scala | 71 ++++++++++++++- .../apache/spark/sql/jdbc/MySQLDialect.scala | 44 ++-------- .../spark/sql/jdbc/PostgresDialect.scala | 61 ++++++++++++- 6 files changed, 170 insertions(+), 124 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 67e81087f8fb0..71adc51b87441 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -18,14 +18,11 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.{Connection, SQLFeatureNotSupportedException} -import java.util import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.index.SupportsIndex -import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} import org.apache.spark.sql.types._ @@ -33,9 +30,9 @@ import org.apache.spark.tags.DockerTest /** * - * To run this test suite for a specific version (e.g., mysql:5.7.31): + * To run this test suite for a specific version (e.g., mysql:5.7.36): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.31 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLIntegrationSuite" * * }}} @@ -45,7 +42,7 @@ import org.apache.spark.tags.DockerTest class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override val catalogName: String = "mysql" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.31") + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) @@ -121,20 +118,5 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def supportsIndex: Boolean = true - override def testIndexProperties(jdbcTable: SupportsIndex): Unit = { - val properties = new util.Properties(); - properties.put("KEY_BLOCK_SIZE", "10") - properties.put("COMMENT", "'this is a comment'") - // MySQL doesn't allow property set on individual column, so use empty Array for - // column properties - jdbcTable.createIndex("i1", "BTREE", Array(FieldReference("col1")), - Array.empty[util.Map[NamedReference, util.Properties]], properties) - - var index = 
jdbcTable.listIndexes() - // The index property size is actually 1. Even though the index is created - // with properties "KEY_BLOCK_SIZE", "10" and "COMMENT", "'this is a comment'", when - // retrieving index using `SHOW INDEXES`, MySQL only returns `COMMENT`. - assert(index(0).properties.size == 1) - assert(index(0).properties.get("COMMENT").equals("this is a comment")) - } + override def indexOptions: String = "KEY_BLOCK_SIZE=10" } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 3ccf051fea52b..1b16b817e7d98 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -80,4 +80,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes } override def supportsTableSample: Boolean = true + + override def supportsIndex: Boolean = true + + override def indexOptions: String = "FILLFACTOR=70" } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index c7c18dab6d660..d26d5ae15e5ce 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.jdbc.v2 -import java.util - import org.apache.log4j.Level import org.apache.spark.sql.{AnalysisException, DataFrame} +import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample} import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} import org.apache.spark.sql.connector.catalog.index.SupportsIndex @@ -191,13 +190,14 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } def supportsIndex: Boolean = false - def testIndexProperties(jdbcTable: SupportsIndex): Unit = {} - test("SPARK-36913: Test INDEX") { + def indexOptions: String = "" + + test("SPARK-36895: Test INDEX Using SQL") { if (supportsIndex) { withTable(s"$catalogName.new_table") { sql(s"CREATE TABLE $catalogName.new_table(col1 INT, col2 INT, col3 INT," + - s" col4 INT, col5 INT)") + " col4 INT, col5 INT)") val loaded = Catalogs.load(catalogName, conf) val jdbcTable = loaded.asInstanceOf[TableCatalog] .loadTable(Identifier.of(Array.empty[String], "new_table")) @@ -205,77 +205,41 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(jdbcTable.indexExists("i1") == false) assert(jdbcTable.indexExists("i2") == false) - val properties = new util.Properties(); val indexType = "DUMMY" var m = intercept[UnsupportedOperationException] { - jdbcTable.createIndex("i1", indexType, Array(FieldReference("col1")), - Array.empty[util.Map[NamedReference, util.Properties]], properties) + sql(s"CREATE index i1 ON $catalogName.new_table USING $indexType (col1)") }.getMessage assert(m.contains(s"Index Type $indexType is not supported." 
+ - s" The supported Index Types are: BTREE and HASH")) - - jdbcTable.createIndex("i1", "BTREE", Array(FieldReference("col1")), - Array.empty[util.Map[NamedReference, util.Properties]], properties) + s" The supported Index Types are:")) - jdbcTable.createIndex("i2", "", - Array(FieldReference("col2"), FieldReference("col3"), FieldReference("col5")), - Array.empty[util.Map[NamedReference, util.Properties]], properties) + sql(s"CREATE index i1 ON $catalogName.new_table USING BTREE (col1)") + sql(s"CREATE index i2 ON $catalogName.new_table (col2, col3, col5)" + + s" OPTIONS ($indexOptions)") assert(jdbcTable.indexExists("i1") == true) assert(jdbcTable.indexExists("i2") == true) + // This should pass without exception + sql(s"CREATE index IF NOT EXISTS i1 ON $catalogName.new_table (col1)") + m = intercept[IndexAlreadyExistsException] { - jdbcTable.createIndex("i1", "", Array(FieldReference("col1")), - Array.empty[util.Map[NamedReference, util.Properties]], properties) + sql(s"CREATE index i1 ON $catalogName.new_table (col1)") }.getMessage - assert(m.contains("Failed to create index: i1 in new_table")) - - var index = jdbcTable.listIndexes() - assert(index.length == 2) - - assert(index(0).indexName.equals("i1")) - assert(index(0).indexType.equals("BTREE")) - var cols = index(0).columns - assert(cols.length == 1) - assert(cols(0).describe().equals("col1")) - assert(index(0).properties.size == 0) - - assert(index(1).indexName.equals("i2")) - assert(index(1).indexType.equals("BTREE")) - cols = index(1).columns - assert(cols.length == 3) - assert(cols(0).describe().equals("col2")) - assert(cols(1).describe().equals("col3")) - assert(cols(2).describe().equals("col5")) - assert(index(1).properties.size == 0) - - jdbcTable.dropIndex("i1") - assert(jdbcTable.indexExists("i1") == false) - assert(jdbcTable.indexExists("i2") == true) - - index = jdbcTable.listIndexes() - assert(index.length == 1) + assert(m.contains("Failed to create index i1 in new_table")) - assert(index(0).indexName.equals("i2")) - assert(index(0).indexType.equals("BTREE")) - cols = index(0).columns - assert(cols.length == 3) - assert(cols(0).describe().equals("col2")) - assert(cols(1).describe().equals("col3")) - assert(cols(2).describe().equals("col5")) + sql(s"DROP index i1 ON $catalogName.new_table") + sql(s"DROP index i2 ON $catalogName.new_table") - jdbcTable.dropIndex("i2") assert(jdbcTable.indexExists("i1") == false) assert(jdbcTable.indexExists("i2") == false) - index = jdbcTable.listIndexes() - assert(index.length == 0) + + // This should pass without exception + sql(s"DROP index IF EXISTS i1 ON $catalogName.new_table") m = intercept[NoSuchIndexException] { - jdbcTable.dropIndex("i2") + sql(s"DROP index i1 ON $catalogName.new_table") }.getMessage - assert(m.contains("Failed to drop index: i2")) - - testIndexProperties(jdbcTable) + assert(m.contains("Failed to drop index i1 in new_table")) } } } @@ -338,7 +302,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(samplePushed(df3)) assert(limitPushed(df3, 2)) assert(columnPruned(df3, "col1")) - assert(df3.collect().length == 2) + assert(df3.collect().length <= 2) // sample(... 
PERCENT) push down + limit push down + column pruning val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" + @@ -346,7 +310,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(samplePushed(df4)) assert(limitPushed(df4, 2)) assert(columnPruned(df4, "col1")) - assert(df4.collect().length == 2) + assert(df4.collect().length <= 2) // sample push down + filter push down + limit push down val df5 = sql(s"SELECT * FROM $catalogName.new_table" + @@ -354,7 +318,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(samplePushed(df5)) assert(filterPushed(df5)) assert(limitPushed(df5, 2)) - assert(df5.collect().length == 2) + assert(df5.collect().length <= 2) // sample + filter + limit + column pruning // sample pushed down, filer/limit not pushed down, column pruned @@ -365,7 +329,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(!filterPushed(df6)) assert(!limitPushed(df6, 2)) assert(columnPruned(df6, "col1")) - assert(df6.collect().length == 2) + assert(df6.collect().length <= 2) // sample + limit // Push down order is sample -> filter -> limit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 2e21571939cf8..3550568483a0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -23,6 +23,8 @@ import java.util import java.util.Locale import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.Try import scala.util.control.NonFatal @@ -38,7 +40,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.TableChange -import org.apache.spark.sql.connector.catalog.index.TableIndex +import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex} import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider @@ -1084,6 +1086,73 @@ object JdbcUtils extends Logging { } } + /** + * Check if index exists in a table + */ + def checkIfIndexExists( + conn: Connection, + sql: String, + options: JDBCOptions): Boolean = { + val statement = conn.createStatement + try { + statement.setQueryTimeout(options.queryTimeout) + val rs = statement.executeQuery(sql) + rs.next + } catch { + case _: Exception => + logWarning("Cannot retrieved index info.") + false + } finally { + statement.close() + } + } + + /** + * Process index properties and return tuple of indexType and list of the other index properties. 
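+   * For example, with catalogName "mysql" and properties {SupportsIndex.PROP_TYPE -> "BTREE", "KEY_BLOCK_SIZE" -> "10"}, this returns ("USING BTREE", Array("KEY_BLOCK_SIZE = 10")).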
+ */ + def processIndexProperties( + properties: util.Map[String, String], + catalogName: String): (String, Array[String]) = { + var indexType = "" + val indexPropertyList: ArrayBuffer[String] = ArrayBuffer[String]() + val supportedIndexTypeList = getSupportedIndexTypeList(catalogName) + + if (!properties.isEmpty) { + properties.asScala.foreach { case (k, v) => + if (k.equals(SupportsIndex.PROP_TYPE)) { + if (containsIndexTypeIgnoreCase(supportedIndexTypeList, v)) { + indexType = s"USING $v" + } else { + throw new UnsupportedOperationException(s"Index Type $v is not supported." + + s" The supported Index Types are: ${supportedIndexTypeList.mkString(" AND ")}") + } + } else { + indexPropertyList.append(s"$k = $v") + } + } + } + (indexType, indexPropertyList.toArray) + } + + def containsIndexTypeIgnoreCase(supportedIndexTypeList: Array[String], value: String): Boolean = { + if (supportedIndexTypeList.isEmpty) { + throw new UnsupportedOperationException( + "Cannot specify 'USING index_type' in 'CREATE INDEX'") + } + for (indexType <- supportedIndexTypeList) { + if (value.equalsIgnoreCase(indexType)) return true + } + false + } + + def getSupportedIndexTypeList(catalogName: String): Array[String] = { + catalogName match { + case "mysql" => Array("BTREE", "HASH") + case "postgresql" => Array("BTREE", "HASH", "BRIN") + case _ => Array.empty + } + } + def executeQuery(conn: Connection, options: JDBCOptions, sql: String): ResultSet = { val statement = conn.createStatement try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index 7e85b3bbb84e8..3fa5481816af0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -21,8 +21,6 @@ import java.sql.{Connection, SQLException, Types} import java.util import java.util.Locale -import scala.collection.JavaConverters._ - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} @@ -115,32 +113,17 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { // https://dev.mysql.com/doc/refman/8.0/en/create-index.html override def createIndex( indexName: String, - indexType: String, tableName: String, columns: Array[NamedReference], columnsProperties: Array[util.Map[NamedReference, util.Properties]], properties: util.Properties): String = { val columnList = columns.map(col => quoteIdentifier(col.fieldNames.head)) - var indexProperties: String = "" - val scalaProps = properties.asScala - if (!properties.isEmpty) { - scalaProps.foreach { case (k, v) => - indexProperties = indexProperties + " " + s"$k $v" - } - } - val iType = if (indexType.isEmpty) { - "" - } else { - if (indexType.length > 1 && !indexType.equalsIgnoreCase("BTREE") && - !indexType.equalsIgnoreCase("HASH")) { - throw new UnsupportedOperationException(s"Index Type $indexType is not supported." 
+ - " The supported Index Types are: BTREE and HASH") - } - s"USING $indexType" - } + val (indexType, indexPropertyList) = JdbcUtils.processIndexProperties(properties, "mysql") + // columnsProperties doesn't apply to MySQL so it is ignored - s"CREATE INDEX ${quoteIdentifier(indexName)} $iType ON" + - s" ${quoteIdentifier(tableName)} (${columnList.mkString(", ")}) $indexProperties" + s"CREATE INDEX ${quoteIdentifier(indexName)} $indexType ON" + + s" ${quoteIdentifier(tableName)} (${columnList.mkString(", ")})" + + s" ${indexPropertyList.mkString(" ")}" } // SHOW INDEX syntax @@ -150,21 +133,8 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { indexName: String, tableName: String, options: JDBCOptions): Boolean = { - val sql = s"SHOW INDEXES FROM ${quoteIdentifier(tableName)}" - try { - val rs = JdbcUtils.executeQuery(conn, options, sql) - while (rs.next()) { - val retrievedIndexName = rs.getString("key_name") - if (conf.resolver(retrievedIndexName, indexName)) { - return true - } - } - false - } catch { - case _: Exception => - logWarning("Cannot retrieved index info.") - false - } + val sql = s"SHOW INDEXES FROM ${quoteIdentifier(tableName)} WHERE key_name = '$indexName'" + JdbcUtils.checkIfIndexExists(conn, sql, options) } override def dropIndex(indexName: String, tableName: String): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 317ae19ed914b..356cb4ddbd008 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -17,15 +17,20 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, Types} +import java.sql.{Connection, SQLException, Types} +import java.util import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ -private object PostgresDialect extends JdbcDialect { +private object PostgresDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") @@ -164,4 +169,56 @@ private object PostgresDialect extends JdbcDialect { s"TABLESAMPLE BERNOULLI" + s" (${(sample.upperBound - sample.lowerBound) * 100}) REPEATABLE (${sample.seed})" } + + // CREATE INDEX syntax + // https://www.postgresql.org/docs/14/sql-createindex.html + override def createIndex( + indexName: String, + tableName: String, + columns: Array[NamedReference], + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String]): String = { + val columnList = columns.map(col => quoteIdentifier(col.fieldNames.head)) + var indexProperties = "" + val (indexType, indexPropertyList) = JdbcUtils.processIndexProperties(properties, "postgresql") + + if (indexPropertyList.nonEmpty) { + indexProperties = "WITH (" + indexPropertyList.mkString(", ") + ")" + } + + s"CREATE INDEX ${quoteIdentifier(indexName)} ON ${quoteIdentifier(tableName)}" + + s" $indexType (${columnList.mkString(", ")}) $indexProperties" + } + + // SHOW INDEX 
syntax + // https://www.postgresql.org/docs/14/view-pg-indexes.html + override def indexExists( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Boolean = { + val sql = s"SELECT * FROM pg_indexes WHERE tablename = '$tableName' AND" + + s" indexname = '$indexName'" + JdbcUtils.checkIfIndexExists(conn, sql, options) + } + + // DROP INDEX syntax + // https://www.postgresql.org/docs/14/sql-dropindex.html + override def dropIndex(indexName: String, tableName: String): String = { + s"DROP INDEX ${quoteIdentifier(indexName)}" + } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getSQLState match { + // https://www.postgresql.org/docs/14/errcodes-appendix.html + case "42P07" => throw new IndexAlreadyExistsException(message, cause = Some(e)) + case "42704" => throw new NoSuchIndexException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case unsupported: UnsupportedOperationException => throw unsupported + case _ => super.classifyException(message, e) + } + } } From 3cac7e66b9514cab7005ed283e8a9be67ea101a7 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 25 Jan 2022 18:50:33 +0800 Subject: [PATCH 21/53] [SPARK-37867][SQL] Compile aggregate functions of built-in JDBC dialects ### What changes were proposed in this pull request? DS V2 translates a lot of standard aggregate functions. Currently, only H2Dialect compiles these standard aggregate functions. This PR compiles these standard aggregate functions for the other built-in JDBC dialects. ### Why are the changes needed? Make the built-in JDBC dialects support complete aggregate push-down. ### Does this PR introduce _any_ user-facing change? 'Yes'. Users can use complete aggregate push-down with the built-in JDBC dialects. ### How was this patch tested? New tests. Closes #35166 from beliefer/SPARK-37867. 
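A minimal sketch of the intended usage, not part of the patch: it assumes a JDBC V2 catalog named `mysql` registered through `JDBCTableCatalog` with `spark.sql.catalog.mysql.pushDownAggregate` set to `true`, and the `employee(dept, name, salary, bonus)` table created by the integration suites below; catalog, namespace and table names are illustrative only.

    // Sketch only (e.g. in spark-shell); assumes the catalog configuration described above.
    val df = spark.sql(
      "SELECT VAR_POP(bonus) FROM mysql.employee WHERE dept > 0 GROUP BY dept")
    // When the dialect can compile the aggregate, no Aggregate node remains in the
    // optimized plan and VAR_POP appears in the JDBC query sent to the database.
    df.explain(true)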
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../sql/jdbc/v2/DB2IntegrationSuite.scala | 17 +- .../v2/DockerJDBCIntegrationV2Suite.scala | 44 ++++ .../jdbc/v2/MsSqlServerIntegrationSuite.scala | 16 +- .../sql/jdbc/v2/MySQLIntegrationSuite.scala | 17 +- .../sql/jdbc/v2/OracleIntegrationSuite.scala | 24 ++- .../jdbc/v2/PostgresIntegrationSuite.scala | 19 +- .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 192 ++++++++++++++++-- .../apache/spark/sql/jdbc/DB2Dialect.scala | 13 ++ .../apache/spark/sql/jdbc/DerbyDialect.scala | 25 +++ .../spark/sql/jdbc/MsSqlServerDialect.scala | 25 +++ .../apache/spark/sql/jdbc/MySQLDialect.scala | 25 +++ .../apache/spark/sql/jdbc/OracleDialect.scala | 37 ++++ .../spark/sql/jdbc/PostgresDialect.scala | 37 ++++ .../spark/sql/jdbc/TeradataDialect.scala | 37 ++++ 14 files changed, 490 insertions(+), 38 deletions(-) create mode 100644 external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index cb0dd1e37e9ff..d0479e9032e06 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection +import java.util.Locale import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -36,8 +37,9 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "db2" + override val namespaceOpt: Option[String] = Some("DB2INST1") override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("DB2_DOCKER_IMAGE_NAME", "ibmcom/db2:11.5.4.0") override val env = Map( @@ -59,8 +61,13 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.db2", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.db2.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.db2.pushDownAggregate", "true") - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)") + .executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -86,4 +93,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) assert(t.schema === expectedSchema) } + + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + + testVarPop() } diff --git 
a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala new file mode 100644 index 0000000000000..72edfc9f1bf1c --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import org.apache.spark.sql.jdbc.DockerJDBCIntegrationSuite + +abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite { + + /** + * Prepare databases and tables for testing. + */ + override def dataPreparation(connection: Connection): Unit = { + tablePreparation(connection) + connection.prepareStatement("INSERT INTO employee VALUES (1, 'amy', 10000, 1000)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (2, 'alex', 12000, 1200)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (1, 'cathy', 9000, 1200)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (2, 'david', 10000, 1300)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)") + .executeUpdate() + } + + def tablePreparation(connection: Connection): Unit +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index b9f5b774a5347..536eb465ceb11 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -36,7 +36,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "mssql" @@ -57,10 +57,15 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBC override 
def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mssql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.mssql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.mssql.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)") + .executeUpdate() + } override def notSupportsTableComment: Boolean = true @@ -90,4 +95,9 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBC assert(msg.contains("UpdateColumnNullability is not supported")) } + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 71adc51b87441..bc4bf54324ee5 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest * */ @DockerTest -class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "mysql" override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") @@ -57,13 +57,17 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mysql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.mysql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.mysql.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) private var mySQLVersion = -1 - override def dataPreparation(conn: Connection): Unit = { - mySQLVersion = conn.getMetaData.getDatabaseMajorVersion + override def tablePreparation(connection: Connection): Unit = { + mySQLVersion = connection.getMetaData.getDatabaseMajorVersion + connection.prepareStatement( + "CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," + + " bonus DOUBLE)").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -119,4 +123,9 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def supportsIndex: Boolean = true override def indexOptions: String = "KEY_BLOCK_SIZE=10" + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala 
b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index 45d793aaa743e..b38f2675243e6 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection +import java.util.Locale import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -53,8 +54,9 @@ import org.apache.spark.tags.DockerTest * It has been validated with 18.4.0 Express Edition. */ @DockerTest -class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "oracle" + override val namespaceOpt: Option[String] = Some("SYSTEM") override val db = new DatabaseOnDocker { lazy override val imageName = sys.env("ORACLE_DOCKER_IMAGE_NAME") override val env = Map( @@ -69,9 +71,15 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.oracle.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.oracle.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) - override def dataPreparation(conn: Connection): Unit = {} + + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," + + " bonus BINARY_DOUBLE)").executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -89,4 +97,14 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest assert(msg1.contains( s"Cannot update $catalogName.alt_table field ID: string cannot be cast to int")) } + + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() + testCovarPop() + testCovarSamp() + testCorr() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 1b16b817e7d98..b3004e1c21c89 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -22,7 +22,7 @@ import java.sql.Connection import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -34,7 
+34,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:13.0-alpine") @@ -51,8 +51,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes .set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort)) .set("spark.sql.catalog.postgresql.pushDownTableSample", "true") .set("spark.sql.catalog.postgresql.pushDownLimit", "true") + .set("spark.sql.catalog.postgresql.pushDownAggregate", "true") - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," + + " bonus double precision)").executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -84,4 +89,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes override def supportsIndex: Boolean = true override def indexOptions: String = "FILLFACTOR=70" + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() + testCovarPop() + testCovarSamp() + testCorr() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index d26d5ae15e5ce..667579b20eaf7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -36,6 +36,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu import testImplicits._ val catalogName: String + + val namespaceOpt: Option[String] = None + + private def catalogAndNamespace = + namespaceOpt.map(namespace => s"$catalogName.$namespace").getOrElse(catalogName) + // dialect specific update column type test def testUpdateColumnType(tbl: String): Unit @@ -246,22 +252,30 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu def supportsTableSample: Boolean = false - private def samplePushed(df: DataFrame): Boolean = { + private def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { val sample = df.queryExecution.optimizedPlan.collect { case s: Sample => s } - sample.isEmpty + if (pushed) { + assert(sample.isEmpty) + } else { + assert(sample.nonEmpty) + } } - private def filterPushed(df: DataFrame): Boolean = { + private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { val filter = df.queryExecution.optimizedPlan.collect { case f: Filter => f } - filter.isEmpty + if (pushed) { + assert(filter.isEmpty) + } else { + assert(filter.nonEmpty) + } } private def limitPushed(df: DataFrame, limit: Int): Boolean = { - val filter = df.queryExecution.optimizedPlan.collect { + df.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => return v1.pushedDownOperators.limit == Some(limit) @@ -270,11 +284,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu false } - private def 
columnPruned(df: DataFrame, col: String): Boolean = { + private def checkColumnPruned(df: DataFrame, col: String): Unit = { val scan = df.queryExecution.optimizedPlan.collectFirst { case s: DataSourceV2ScanRelation => s }.get - scan.schema.names.sameElements(Seq(col)) + assert(scan.schema.names.sameElements(Seq(col))) } test("SPARK-37038: Test TABLESAMPLE") { @@ -286,37 +300,37 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu // sample push down + column pruning val df1 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + " REPEATABLE (12345)") - assert(samplePushed(df1)) - assert(columnPruned(df1, "col1")) + checkSamplePushed(df1) + checkColumnPruned(df1, "col1") assert(df1.collect().length < 10) // sample push down only val df2 = sql(s"SELECT * FROM $catalogName.new_table TABLESAMPLE (50 PERCENT)" + " REPEATABLE (12345)") - assert(samplePushed(df2)) + checkSamplePushed(df2) assert(df2.collect().length < 10) // sample(BUCKET ... OUT OF) push down + limit push down + column pruning val df3 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + " LIMIT 2") - assert(samplePushed(df3)) + checkSamplePushed(df3) assert(limitPushed(df3, 2)) - assert(columnPruned(df3, "col1")) + checkColumnPruned(df3, "col1") assert(df3.collect().length <= 2) // sample(... PERCENT) push down + limit push down + column pruning val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" + " TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 2") - assert(samplePushed(df4)) + checkSamplePushed(df4) assert(limitPushed(df4, 2)) - assert(columnPruned(df4, "col1")) + checkColumnPruned(df4, "col1") assert(df4.collect().length <= 2) // sample push down + filter push down + limit push down val df5 = sql(s"SELECT * FROM $catalogName.new_table" + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") - assert(samplePushed(df5)) - assert(filterPushed(df5)) + checkSamplePushed(df5) + checkFilterPushed(df5) assert(limitPushed(df5, 2)) assert(df5.collect().length <= 2) @@ -325,27 +339,161 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu // Todo: push down filter/limit val df6 = sql(s"SELECT col1 FROM $catalogName.new_table" + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") - assert(samplePushed(df6)) - assert(!filterPushed(df6)) + checkSamplePushed(df6) + checkFilterPushed(df6, false) assert(!limitPushed(df6, 2)) - assert(columnPruned(df6, "col1")) + checkColumnPruned(df6, "col1") assert(df6.collect().length <= 2) // sample + limit // Push down order is sample -> filter -> limit // only limit is pushed down because in this test sample is after limit val df7 = spark.read.table(s"$catalogName.new_table").limit(2).sample(0.5) - assert(!samplePushed(df7)) + checkSamplePushed(df7, false) assert(limitPushed(df7, 2)) // sample + filter // Push down order is sample -> filter -> limit // only filter is pushed down because in this test sample is after filter val df8 = spark.read.table(s"$catalogName.new_table").where($"col1" > 1).sample(0.5) - assert(!samplePushed(df8)) - assert(filterPushed(df8)) + checkSamplePushed(df8, false) + checkFilterPushed(df8) assert(df8.collect().length < 10) } } } + + protected def checkAggregateRemoved(df: DataFrame): Unit = { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + assert(aggregates.isEmpty) + } + + private def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { + 
df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, _) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions.length == 1) + assert(aggregationExpressions(0).isInstanceOf[GeneralAggregateFunc]) + assert(aggregationExpressions(0).asInstanceOf[GeneralAggregateFunc].name() == funcName) + } + } + + protected def caseConvert(tableName: String): String = tableName + + protected def testVarPop(): Unit = { + test(s"scan with aggregate push-down: VAR_POP") { + val df = sql(s"SELECT VAR_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "VAR_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 10000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testVarSamp(): Unit = { + test(s"scan with aggregate push-down: VAR_SAMP") { + val df = sql( + s"SELECT VAR_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "VAR_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000d) + assert(row(1).getDouble(0) === 5000d) + assert(row(2).isNullAt(0)) + } + } + + protected def testStddevPop(): Unit = { + test("scan with aggregate push-down: STDDEV_POP") { + val df = sql( + s"SELECT STDDEV_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "STDDEV_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 100d) + assert(row(1).getDouble(0) === 50d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testStddevSamp(): Unit = { + test("scan with aggregate push-down: STDDEV_SAMP") { + val df = sql( + s"SELECT STDDEV_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "STDDEV_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 141.4213562373095d) + assert(row(1).getDouble(0) === 70.71067811865476d) + assert(row(2).isNullAt(0)) + } + } + + protected def testCovarPop(): Unit = { + test("scan with aggregate push-down: COVAR_POP") { + val df = sql( + s"SELECT COVAR_POP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "COVAR_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 10000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testCovarSamp(): Unit = { + test("scan with aggregate push-down: COVAR_SAMP") { + val df = sql( + s"SELECT COVAR_SAMP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, 
"COVAR_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000d) + assert(row(1).getDouble(0) === 5000d) + assert(row(2).isNullAt(0)) + } + } + + protected def testCorr(): Unit = { + test("scan with aggregate push-down: CORR") { + val df = sql( + s"SELECT CORR(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "CORR") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 1d) + assert(row(1).getDouble(0) === 1d) + assert(row(2).isNullAt(0)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 0b394db5c8932..9e9aac679ab39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { @@ -27,6 +28,18 @@ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARIANCE($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index f19ef7ead5f8e..e87d4d08ae031 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -29,6 +30,30 @@ private object DerbyDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = 
if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) Option(FloatType) else None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 8e5674a181e7a..442c5599b3ab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -36,6 +37,30 @@ private object MsSqlServerDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDEVP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDEV($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (typeName.contains("datetimeoffset")) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index 3fa5481816af0..8316b3c04e107 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder} @@ -35,6 +36,30 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url : String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + 
super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index b741ece8dda9b..4fe7d93142c1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp, Types} import java.util.{Locale, TimeZone} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -33,6 +34,42 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + case _ => None + } + 
) + } + private def supportTimeZoneTypes: Boolean = { val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) // TODO: support timezone types when users are not using the JVM timezone, which diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 356cb4ddbd008..3b1a2c81fffd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ @@ -35,6 +36,42 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 13f4c5fe9c926..6344667b3180e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.util.Locale +import 
org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ @@ -27,6 +28,42 @@ private case object TeradataDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + case _ => None + } + ) + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) From b5111bc24a787bb78eb78a4f051a46b1a171c1fa Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 26 Jan 2022 13:29:23 +0800 Subject: [PATCH 22/53] [SPARK-37929][SQL][FOLLOWUP] Support cascade mode for JDBC V2 ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/35246 added support for `cascade` mode in the dropNamespace API. This PR follows up https://github.com/apache/spark/pull/35246 to make JDBC V2 respect `cascade`. ### Why are the changes needed? Let JDBC V2 respect `cascade`. ### Does this PR introduce _any_ user-facing change? Yes. Users can now drop a namespace with `cascade` on JDBC V2. ### How was this patch tested? New tests. Closes #35271 from beliefer/SPARK-37929-followup. 
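A minimal sketch of the resulting behavior, not part of the patch: it assumes a JDBC V2 catalog named `postgresql` backed by `JDBCTableCatalog`, as configured in the integration suites; the namespace and table names are illustrative only.

    // Sketch only (e.g. in spark-shell); assumes the catalog configuration described above.
    spark.sql("CREATE NAMESPACE postgresql.foo")
    spark.sql("CREATE TABLE postgresql.foo.tab (id INT, data STRING)")
    // Dropping a non-empty namespace without CASCADE is expected to fail
    // (NonEmptyNamespaceException); with CASCADE, JdbcUtils issues
    // "DROP SCHEMA ... CASCADE" and removes the schema together with its tables.
    spark.sql("DROP NAMESPACE postgresql.foo CASCADE")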
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../sql/jdbc/v2/V2JDBCNamespaceTest.scala | 34 ++++++++++++++++++- .../datasources/jdbc/JdbcUtils.scala | 10 ++++-- .../v2/jdbc/JDBCTableCatalog.scala | 5 +-- .../spark/sql/jdbc/PostgresDialect.scala | 3 +- 4 files changed, 44 insertions(+), 8 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index 95d59fec2fac6..4baed69b79f22 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -17,21 +17,31 @@ package org.apache.spark.sql.jdbc.v2 +import java.util +import java.util.Collections + import scala.collection.JavaConverters._ import org.apache.log4j.Level import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.NamespaceChange +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException +import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.tags.DockerTest @DockerTest private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerIntegrationFunSuite { val catalog = new JDBCTableCatalog() + private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val schema: StructType = new StructType() + .add("id", IntegerType) + .add("data", StringType) + def builtinNamespaces: Array[Array[String]] test("listNamespaces: basic behavior") { @@ -60,4 +70,26 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte }.getMessage assert(msg.contains("Namespace 'foo' not found")) } + + test("Drop namespace") { + val ident1 = Identifier.of(Array("foo"), "tab") + // Drop empty namespace without cascade + catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.dropNamespace(Array("foo"), cascade = false) + assert(catalog.namespaceExists(Array("foo")) === false) + + // Drop non empty namespace without cascade + catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.createTable(ident1, schema, Array.empty, emptyProps) + intercept[NonEmptyNamespaceException] { + catalog.dropNamespace(Array("foo"), cascade = false) + } + + // Drop non empty namespace with cascade + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.dropNamespace(Array("foo"), cascade = true) + assert(catalog.namespaceExists(Array("foo")) === false) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 3550568483a0c..5d62b96c9ce53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -1019,9 +1019,15 @@ object JdbcUtils extends Logging { /** * 
Drops a namespace from the JDBC database. */ - def dropNamespace(conn: Connection, options: JDBCOptions, namespace: String): Unit = { + def dropNamespace( + conn: Connection, options: JDBCOptions, namespace: String, cascade: Boolean): Unit = { val dialect = JdbcDialects.get(options.url) - executeStatement(conn, options, s"DROP SCHEMA ${dialect.quoteIdentifier(namespace)}") + val dropCmd = if (cascade) { + s"DROP SCHEMA ${dialect.quoteIdentifier(namespace)} CASCADE" + } else { + s"DROP SCHEMA ${dialect.quoteIdentifier(namespace)}" + } + executeStatement(conn, options, dropCmd) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala index 566706486d3f0..58ad3f98d120b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala @@ -280,12 +280,9 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def dropNamespace(namespace: Array[String]): Boolean = namespace match { case Array(db) if namespaceExists(namespace) => - if (listTables(Array(db)).nonEmpty) { - throw QueryExecutionErrors.namespaceNotEmptyError(namespace) - } JdbcUtils.withConnection(options) { conn => JdbcUtils.classifyException(s"Failed drop name space: $db", dialect) { - JdbcUtils.dropNamespace(conn, options, db) + JdbcUtils.dropNamespace(conn, options, db, cascade) true } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 3b1a2c81fffd6..46e79404f3e54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -23,7 +23,7 @@ import java.util.Locale import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} +import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException} import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} @@ -252,6 +252,7 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { // https://www.postgresql.org/docs/14/errcodes-appendix.html case "42P07" => throw new IndexAlreadyExistsException(message, cause = Some(e)) case "42704" => throw new NoSuchIndexException(message, cause = Some(e)) + case "2BP01" => throw NonEmptyNamespaceException(message, cause = Some(e)) case _ => super.classifyException(message, e) } case unsupported: UnsupportedOperationException => throw unsupported From 229db0e51eb128f171e619b317db98c28222389e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 28 Jan 2022 17:39:55 +0800 Subject: [PATCH 23/53] [SPARK-38035][SQL] Add docker tests for built-in JDBC dialects ### What changes were proposed in this pull request? Currently, Spark only has `PostgresNamespaceSuite` to test DS V2 namespaces in a Docker environment. Tests for the other built-in JDBC dialects (e.g. Oracle, MySQL) are missing. This PR also found some compatibility issues. 
For example, the JDBC API `conn.getMetaData.getSchemas` does not work well for MySQL. ### Why are the changes needed? We need to add tests for the other built-in JDBC dialects. ### Does this PR introduce _any_ user-facing change? 'No'. It just adds developer-facing tests. ### How was this patch tested? New tests. Closes #35333 from beliefer/SPARK-38035. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- external/docker-integration-tests/pom.xml | 5 + .../spark/sql/jdbc/v2/DB2NamespaceSuite.scala | 74 ++++++++++++ .../jdbc/v2/MsSqlServerNamespaceSuite.scala | 76 ++++++++++++ .../sql/jdbc/v2/MySQLIntegrationSuite.scala | 3 - .../sql/jdbc/v2/MySQLNamespaceSuite.scala | 65 ++++++++++ .../sql/jdbc/v2/OracleNamespaceSuite.scala | 86 ++++++++++++ .../sql/jdbc/v2/PostgresNamespaceSuite.scala | 6 +- .../sql/jdbc/v2/V2JDBCNamespaceTest.scala | 112 +++++++++++------- .../datasources/jdbc/JdbcUtils.scala | 7 +- .../apache/spark/sql/jdbc/DB2Dialect.scala | 28 ++++- .../apache/spark/sql/jdbc/JdbcDialects.scala | 8 ++ .../spark/sql/jdbc/MsSqlServerDialect.scala | 14 +++ 12 files changed, 429 insertions(+), 55 deletions(-) create mode 100644 external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala create mode 100644 external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala create mode 100644 external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala create mode 100644 external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 14d8da6a1613e..282d64e3459c6 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -162,5 +162,10 @@ mssql-jdbc test + + mysql + mysql-connector-java + test + diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala new file mode 100644 index 0000000000000..f0e98fc2722b0 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.6.0a): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.6.0a + * ./build/sbt -Pdocker-integration-tests "testOnly *v2.DB2NamespaceSuite" + * }}} + */ +@DockerTest +class DB2NamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("DB2_DOCKER_IMAGE_NAME", "ibmcom/db2:11.5.6.0a") + override val env = Map( + "DB2INST1_PASSWORD" -> "rootpass", + "LICENSE" -> "accept", + "DBNAME" -> "db2foo", + "ARCHIVE_LOGS" -> "false", + "AUTOCONFIG" -> "false" + ) + override val usesIpc = false + override val jdbcPort: Int = 50000 + override val privileged = true + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:db2://$ip:$port/db2foo:user=db2inst1;password=rootpass;retrieveMessagesFromServerOnGetMessage=true;" //scalastyle:ignore + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "com.ibm.db2.jcc.DB2Driver").asJava) + + catalog.initialize("db2", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("NULLID"), Array("SQLJ"), Array("SYSCAT"), Array("SYSFUN"), + Array("SYSIBM"), Array("SYSIBMADM"), Array("SYSIBMINTERNAL"), Array("SYSIBMTS"), + Array("SYSPROC"), Array("SYSPUBLIC"), Array("SYSSTAT"), Array("SYSTOOLS")) + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + builtinNamespaces ++ Array(namespace) + } + + override val supportsDropSchemaCascade: Boolean = false + + testListNamespaces() + testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala new file mode 100644 index 0000000000000..aa8dac266380a --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * To run this test suite for a specific version (e.g., 2019-CU13-ubuntu-20.04): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04 + * ./build/sbt -Pdocker-integration-tests "testOnly *v2.MsSqlServerNamespaceSuite" + * }}} + */ +@DockerTest +class MsSqlServerNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", + "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") + override val env = Map( + "SA_PASSWORD" -> "Sapass123", + "ACCEPT_EULA" -> "Y" + ) + override val usesIpc = false + override val jdbcPort: Int = 1433 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "com.microsoft.sqlserver.jdbc.SQLServerDriver").asJava) + + catalog.initialize("mssql", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("db_accessadmin"), Array("db_backupoperator"), Array("db_datareader"), + Array("db_datawriter"), Array("db_ddladmin"), Array("db_denydatareader"), + Array("db_denydatawriter"), Array("db_owner"), Array("db_securityadmin"), Array("dbo"), + Array("guest"), Array("INFORMATION_SCHEMA"), Array("sys")) + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + builtinNamespaces ++ Array(namespace) + } + + override val supportsSchemaComment: Boolean = false + + override val supportsDropSchemaCascade: Boolean = false + + testListNamespaces() + testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index bc4bf54324ee5..97f521a378eb7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -29,14 +29,11 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * * To run this test suite for a specific version (e.g., mysql:5.7.36): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLIntegrationSuite" - * * }}} - * */ @DockerTest class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala new file mode 100644 index 0000000000000..d3230155b8923 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * To run this test suite for a specific version (e.g., mysql:5.7.36): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 + * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLNamespaceSuite" + * }}} + */ +@DockerTest +class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val usesIpc = false + override val jdbcPort: Int = 3306 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/" + + s"mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true&useSSL=false" + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "com.mysql.jdbc.Driver").asJava) + + catalog.initialize("mysql", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = Array() + + override val supportsSchemaComment: Boolean = false + + // Cannot get namespaces with conn.getMetaData.getSchemas + // TODO testListNamespaces() + // TODO testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala new file mode 100644 index 0000000000000..31f26d2990666 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * The following are the steps to test this: + * + * 1. Choose to use a prebuilt image or build Oracle database in a container + * - The documentation on how to build Oracle RDBMS in a container is at + * https://github.com/oracle/docker-images/blob/master/OracleDatabase/SingleInstance/README.md + * - Official Oracle container images can be found at https://container-registry.oracle.com + * - A trustable and streamlined Oracle XE database image can be found on Docker Hub at + * https://hub.docker.com/r/gvenzl/oracle-xe see also https://github.com/gvenzl/oci-oracle-xe + * 2. Run: export ORACLE_DOCKER_IMAGE_NAME=image_you_want_to_use_for_testing + * - Example: export ORACLE_DOCKER_IMAGE_NAME=gvenzl/oracle-xe:latest + * 3. Run: export ENABLE_DOCKER_INTEGRATION_TESTS=1 + * 4. Start docker: sudo service docker start + * - Optionally, docker pull $ORACLE_DOCKER_IMAGE_NAME + * 5. Run Spark integration tests for Oracle with: ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.v2.OracleNamespaceSuite" + * + * A sequence of commands to build the Oracle XE database container image: + * $ git clone https://github.com/oracle/docker-images.git + * $ cd docker-images/OracleDatabase/SingleInstance/dockerfiles + * $ ./buildContainerImage.sh -v 18.4.0 -x + * $ export ORACLE_DOCKER_IMAGE_NAME=oracle/database:18.4.0-xe + * + * This procedure has been validated with Oracle 18.4.0 Express Edition. + */ +@DockerTest +class OracleNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + lazy override val imageName = + sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-xe:18.4.0") + val oracle_password = "Th1s1sThe0racle#Pass" + override val env = Map( + "ORACLE_PWD" -> oracle_password, // oracle images uses this + "ORACLE_PASSWORD" -> oracle_password // gvenzl/oracle-xe uses this + ) + override val usesIpc = false + override val jdbcPort: Int = 1521 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:oracle:thin:system/$oracle_password@//$ip:$port/xe" + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "oracle.jdbc.OracleDriver").asJava) + + catalog.initialize("system", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("ANONYMOUS"), Array("APEX_030200"), Array("APEX_PUBLIC_USER"), Array("APPQOSSYS"), + Array("BI"), Array("DIP"), Array("FLOWS_FILES"), Array("HR"), Array("OE"), Array("PM"), + Array("SCOTT"), Array("SH"), Array("SPATIAL_CSW_ADMIN_USR"), Array("SPATIAL_WFS_ADMIN_USR"), + Array("XS$NULL")) + + // Cannot create schema dynamically + // TODO testListNamespaces() + // TODO testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index b5cf3dfcb474d..4a615bddd7dfb 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ 
b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -53,7 +53,9 @@ class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNames override def dataPreparation(conn: Connection): Unit = {} - override def builtinNamespaces: Array[Array[String]] = { + override def builtinNamespaces: Array[Array[String]] = Array(Array("information_schema"), Array("pg_catalog"), Array("public")) - } + + testListNamespaces() + testDropNamespaces() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index 4baed69b79f22..8d97ac45568e3 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -22,7 +22,7 @@ import java.util.Collections import scala.collection.JavaConverters._ -import org.apache.log4j.Level +import org.apache.logging.log4j.Level import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException @@ -44,52 +44,78 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte def builtinNamespaces: Array[Array[String]] - test("listNamespaces: basic behavior") { - catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) - assert(catalog.listNamespaces() === Array(Array("foo")) ++ builtinNamespaces) - assert(catalog.listNamespaces(Array("foo")) === Array()) - assert(catalog.namespaceExists(Array("foo")) === true) - - val logAppender = new LogAppender("catalog comment") - withLogAppender(logAppender) { - catalog.alterNamespace(Array("foo"), NamespaceChange - .setProperty("comment", "comment for foo")) - catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) - } - val createCommentWarning = logAppender.loggingEvents - .filter(_.getLevel == Level.WARN) - .map(_.getRenderedMessage) - .exists(_.contains("catalog comment")) - assert(createCommentWarning === false) - - catalog.dropNamespace(Array("foo")) - assert(catalog.namespaceExists(Array("foo")) === false) - assert(catalog.listNamespaces() === builtinNamespaces) - val msg = intercept[AnalysisException] { - catalog.listNamespaces(Array("foo")) - }.getMessage - assert(msg.contains("Namespace 'foo' not found")) + def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + Array(namespace) ++ builtinNamespaces } - test("Drop namespace") { - val ident1 = Identifier.of(Array("foo"), "tab") - // Drop empty namespace without cascade - catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) - assert(catalog.namespaceExists(Array("foo")) === true) - catalog.dropNamespace(Array("foo"), cascade = false) - assert(catalog.namespaceExists(Array("foo")) === false) - - // Drop non empty namespace without cascade - catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) - assert(catalog.namespaceExists(Array("foo")) === true) - catalog.createTable(ident1, schema, Array.empty, emptyProps) - intercept[NonEmptyNamespaceException] { + def supportsSchemaComment: Boolean = true + + def supportsDropSchemaCascade: Boolean = true + + def testListNamespaces(): Unit = { + test("listNamespaces: basic behavior") { + val commentMap = if (supportsSchemaComment) { + Map("comment" -> "test comment") 
+ } else { + Map.empty[String, String] + } + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.listNamespaces() === listNamespaces(Array("foo"))) + assert(catalog.listNamespaces(Array("foo")) === Array()) + assert(catalog.namespaceExists(Array("foo")) === true) + + if (supportsSchemaComment) { + val logAppender = new LogAppender("catalog comment") + withLogAppender(logAppender) { + catalog.alterNamespace(Array("foo"), NamespaceChange + .setProperty("comment", "comment for foo")) + catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) + } + val createCommentWarning = logAppender.loggingEvents + .filter(_.getLevel == Level.WARN) + .map(_.getMessage.getFormattedMessage) + .exists(_.contains("catalog comment")) + assert(createCommentWarning === false) + } + catalog.dropNamespace(Array("foo"), cascade = false) + assert(catalog.namespaceExists(Array("foo")) === false) + assert(catalog.listNamespaces() === builtinNamespaces) + val msg = intercept[AnalysisException] { + catalog.listNamespaces(Array("foo")) + }.getMessage + assert(msg.contains("Namespace 'foo' not found")) } + } + + def testDropNamespaces(): Unit = { + test("Drop namespace") { + val ident1 = Identifier.of(Array("foo"), "tab") + // Drop empty namespace without cascade + val commentMap = if (supportsSchemaComment) { + Map("comment" -> "test comment") + } else { + Map.empty[String, String] + } + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.dropNamespace(Array("foo"), cascade = false) + assert(catalog.namespaceExists(Array("foo")) === false) - // Drop non empty namespace with cascade - assert(catalog.namespaceExists(Array("foo")) === true) - catalog.dropNamespace(Array("foo"), cascade = true) - assert(catalog.namespaceExists(Array("foo")) === false) + // Drop non empty namespace without cascade + catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.createTable(ident1, schema, Array.empty, emptyProps) + intercept[NonEmptyNamespaceException] { + catalog.dropNamespace(Array("foo"), cascade = false) + } + + // Drop non empty namespace with cascade + if (supportsDropSchemaCascade) { + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.dropNamespace(Array("foo"), cascade = true) + assert(catalog.namespaceExists(Array("foo")) === false) + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 5d62b96c9ce53..56bfa80d1bb7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -1022,12 +1022,7 @@ object JdbcUtils extends Logging { def dropNamespace( conn: Connection, options: JDBCOptions, namespace: String, cascade: Boolean): Unit = { val dialect = JdbcDialects.get(options.url) - val dropCmd = if (cascade) { - s"DROP SCHEMA ${dialect.quoteIdentifier(namespace)} CASCADE" - } else { - s"DROP SCHEMA ${dialect.quoteIdentifier(namespace)}" - } - executeStatement(conn, options, dropCmd) + executeStatement(conn, options, dialect.dropSchema(namespace, cascade)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 
9e9aac679ab39..ffda7545c6e9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{SQLException, Types} import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ @@ -92,4 +94,28 @@ private object DB2Dialect extends JdbcDialect { val nullable = if (isNullable) "DROP NOT NULL" else "SET NOT NULL" s"ALTER TABLE $tableName ALTER COLUMN ${quoteIdentifier(columnName)} $nullable" } + + override def removeSchemaCommentQuery(schema: String): String = { + s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS ''" + } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getSQLState match { + // https://www.ibm.com/docs/en/db2/11.5?topic=messages-sqlstate + case "42893" => throw NonEmptyNamespaceException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case _ => super.classifyException(message, e) + } + } + + override def dropSchema(schema: String, cascade: Boolean): String = { + if (cascade) { + s"DROP SCHEMA ${quoteIdentifier(schema)} CASCADE" + } else { + s"DROP SCHEMA ${quoteIdentifier(schema)} RESTRICT" + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index dbf5e4c037d31..7579be2ba20c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -327,6 +327,14 @@ abstract class JdbcDialect extends Serializable with Logging{ s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS NULL" } + def dropSchema(schema: String, cascade: Boolean): String = { + if (cascade) { + s"DROP SCHEMA ${quoteIdentifier(schema)} CASCADE" + } else { + s"DROP SCHEMA ${quoteIdentifier(schema)}" + } + } + /** * Build a create index SQL statement. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 442c5599b3ab3..3d8a48a66ea8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.jdbc +import java.sql.SQLException import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -147,4 +150,15 @@ private object MsSqlServerDialect extends JdbcDialect { override def getLimitClause(limit: Integer): String = { "" } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getErrorCode match { + case 3729 => throw NonEmptyNamespaceException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case _ => super.classifyException(message, e) + } + } }
From 2227b1373f60bfb402f53aa2e0d84626570b4746 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 10 Feb 2022 21:32:18 +0800 Subject: [PATCH 24/53] [SPARK-38054][SQL] Support listing namespaces in JDBC v2 MySQL dialect
### What changes were proposed in this pull request? Currently, `JDBCTableCatalog.scala` queries namespaces as shown below.
```
val schemaBuilder = ArrayBuilder.make[Array[String]]
val rs = conn.getMetaData.getSchemas()
while (rs.next()) {
  schemaBuilder += Array(rs.getString(1))
}
schemaBuilder.result
```
But this code cannot get any information when using the MySQL JDBC driver, so this PR uses `SHOW SCHEMAS` to query the namespaces of MySQL. This PR also fixes the following issues: - Enable the docker tests in `MySQLNamespaceSuite.scala`. - Because MySQL doesn't support creating a comment on a schema, throw `SQLFeatureNotSupportedException`. - Because MySQL doesn't support `DROP SCHEMA` in `RESTRICT` mode, throw `SQLFeatureNotSupportedException`. - Refactor `JdbcUtils.executeQuery` to avoid `java.sql.SQLException: Operation not allowed after ResultSet closed` (see the sketch below).
### Why are the changes needed? So that the MySQL dialect supports listing namespaces.
### Does this PR introduce _any_ user-facing change? 'Yes'. Some APIs changed.
### How was this patch tested? New tests.
Closes #35355 from beliefer/SPARK-38054.
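To illustrate the refactored callback style, here is a standalone sketch (the `runQuery` name and surrounding plumbing are illustrative, not the exact helper added to `JdbcUtils`): the `ResultSet` is consumed inside the callback and both it and its `Statement` are closed before the method returns, so callers can no longer hold on to a `ResultSet` whose `Statement` has already been closed.
```
import java.sql.{Connection, ResultSet}

// Loan-pattern query helper: hand the live ResultSet to `f`, then always clean up.
def runQuery[T](conn: Connection, sql: String)(f: ResultSet => T): T = {
  val statement = conn.createStatement
  try {
    val rs = statement.executeQuery(sql)
    try f(rs) finally rs.close()
  } finally {
    statement.close()
  }
}

// Example: collect database names via MySQL's SHOW SCHEMAS.
// val names = runQuery(conn, "SHOW SCHEMAS") { rs =>
//   val buf = scala.collection.mutable.ArrayBuffer.empty[String]
//   while (rs.next()) buf += rs.getString(1)
//   buf.toArray
// }
```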
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../sql/jdbc/v2/MySQLNamespaceSuite.scala | 48 ++++++++++-- .../sql/jdbc/v2/V2JDBCNamespaceTest.scala | 22 ++++-- .../sql/errors/QueryExecutionErrors.scala | 12 +++ .../datasources/jdbc/JdbcUtils.scala | 66 +++++++++------- .../v2/jdbc/JDBCTableCatalog.scala | 26 +++---- .../apache/spark/sql/jdbc/JdbcDialects.scala | 41 +++++++++- .../apache/spark/sql/jdbc/MySQLDialect.scala | 78 ++++++++++++++----- 7 files changed, 218 insertions(+), 75 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala index d3230155b8923..d8dee61d70ea6 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.jdbc.v2 -import java.sql.Connection +import java.sql.{Connection, SQLFeatureNotSupportedException} import scala.collection.JavaConverters._ +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.NamespaceChange import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest @@ -55,11 +57,47 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac override def dataPreparation(conn: Connection): Unit = {} - override def builtinNamespaces: Array[Array[String]] = Array() + override def builtinNamespaces: Array[Array[String]] = + Array(Array("information_schema"), Array("mysql"), Array("performance_schema"), Array("sys")) + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + Array(builtinNamespaces.head, namespace) ++ builtinNamespaces.tail + } override val supportsSchemaComment: Boolean = false - // Cannot get namespaces with conn.getMetaData.getSchemas - // TODO testListNamespaces() - // TODO testDropNamespaces() + override val supportsDropSchemaRestrict: Boolean = false + + testListNamespaces() + testDropNamespaces() + + test("Create or remove comment of namespace unsupported") { + val e1 = intercept[AnalysisException] { + catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) + } + assert(e1.getMessage.contains("Failed create name space: foo")) + assert(e1.getCause.isInstanceOf[SQLFeatureNotSupportedException]) + assert(e1.getCause.asInstanceOf[SQLFeatureNotSupportedException].getMessage + .contains("Create namespace comment is not supported")) + assert(catalog.namespaceExists(Array("foo")) === false) + catalog.createNamespace(Array("foo"), Map.empty[String, String].asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + val e2 = intercept[AnalysisException] { + catalog.alterNamespace(Array("foo"), NamespaceChange + .setProperty("comment", "comment for foo")) + } + assert(e2.getMessage.contains("Failed create comment on name space: foo")) + assert(e2.getCause.isInstanceOf[SQLFeatureNotSupportedException]) + assert(e2.getCause.asInstanceOf[SQLFeatureNotSupportedException].getMessage + .contains("Create namespace comment is not supported")) + val e3 = intercept[AnalysisException] { + catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) + } + assert(e3.getMessage.contains("Failed remove comment on 
name space: foo")) + assert(e3.getCause.isInstanceOf[SQLFeatureNotSupportedException]) + assert(e3.getCause.asInstanceOf[SQLFeatureNotSupportedException].getMessage + .contains("Remove namespace comment is not supported")) + catalog.dropNamespace(Array("foo"), cascade = true) + assert(catalog.namespaceExists(Array("foo")) === false) + } } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index 8d97ac45568e3..bae0d7c361635 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -52,6 +52,8 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte def supportsDropSchemaCascade: Boolean = true + def supportsDropSchemaRestrict: Boolean = true + def testListNamespaces(): Unit = { test("listNamespaces: basic behavior") { val commentMap = if (supportsSchemaComment) { @@ -78,7 +80,11 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte assert(createCommentWarning === false) } - catalog.dropNamespace(Array("foo"), cascade = false) + if (supportsDropSchemaRestrict) { + catalog.dropNamespace(Array("foo"), cascade = false) + } else { + catalog.dropNamespace(Array("foo"), cascade = true) + } assert(catalog.namespaceExists(Array("foo")) === false) assert(catalog.listNamespaces() === builtinNamespaces) val msg = intercept[AnalysisException] { @@ -99,15 +105,21 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte } catalog.createNamespace(Array("foo"), commentMap.asJava) assert(catalog.namespaceExists(Array("foo")) === true) - catalog.dropNamespace(Array("foo"), cascade = false) + if (supportsDropSchemaRestrict) { + catalog.dropNamespace(Array("foo"), cascade = false) + } else { + catalog.dropNamespace(Array("foo"), cascade = true) + } assert(catalog.namespaceExists(Array("foo")) === false) // Drop non empty namespace without cascade - catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) + catalog.createNamespace(Array("foo"), commentMap.asJava) assert(catalog.namespaceExists(Array("foo")) === true) catalog.createTable(ident1, schema, Array.empty, emptyProps) - intercept[NonEmptyNamespaceException] { - catalog.dropNamespace(Array("foo"), cascade = false) + if (supportsDropSchemaRestrict) { + intercept[NonEmptyNamespaceException] { + catalog.dropNamespace(Array("foo"), cascade = false) + } } // Drop non empty namespace with cascade diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 7f77243af8a88..88ab9e530a1a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1804,4 +1804,16 @@ object QueryExecutionErrors { def pivotNotAfterGroupByUnsupportedError(): Throwable = { new UnsupportedOperationException("pivot is only supported after a groupBy") } + + def unsupportedCreateNamespaceCommentError(): Throwable = { + new SQLFeatureNotSupportedException("Create namespace comment is not supported") + } + + def unsupportedRemoveNamespaceCommentError(): Throwable = { + new 
SQLFeatureNotSupportedException("Remove namespace comment is not supported") + } + + def unsupportedDropNamespaceRestrictError(): Throwable = { + new SQLFeatureNotSupportedException("Drop namespace restrict is not supported") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 56bfa80d1bb7f..b554814f1e193 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -976,53 +976,57 @@ object JdbcUtils extends Logging { } /** - * Creates a namespace. + * Creates a schema. */ - def createNamespace( + def createSchema( conn: Connection, options: JDBCOptions, - namespace: String, + schema: String, comment: String): Unit = { + val statement = conn.createStatement + try { + statement.setQueryTimeout(options.queryTimeout) + val dialect = JdbcDialects.get(options.url) + dialect.createSchema(statement, schema, comment) + } finally { + statement.close() + } + } + + def schemaExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = { + val dialect = JdbcDialects.get(options.url) + dialect.schemasExists(conn, options, schema) + } + + def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = { val dialect = JdbcDialects.get(options.url) - executeStatement(conn, options, s"CREATE SCHEMA ${dialect.quoteIdentifier(namespace)}") - if (!comment.isEmpty) createNamespaceComment(conn, options, namespace, comment) + dialect.listSchemas(conn, options) } - def createNamespaceComment( + def alterSchemaComment( conn: Connection, options: JDBCOptions, - namespace: String, + schema: String, comment: String): Unit = { val dialect = JdbcDialects.get(options.url) - try { - executeStatement( - conn, options, dialect.getSchemaCommentQuery(namespace, comment)) - } catch { - case e: Exception => - logWarning("Cannot create JDBC catalog comment. The catalog comment will be ignored.") - } + executeStatement(conn, options, dialect.getSchemaCommentQuery(schema, comment)) } - def removeNamespaceComment( + def removeSchemaComment( conn: Connection, options: JDBCOptions, - namespace: String): Unit = { + schema: String): Unit = { val dialect = JdbcDialects.get(options.url) - try { - executeStatement(conn, options, dialect.removeSchemaCommentQuery(namespace)) - } catch { - case e: Exception => - logWarning("Cannot drop JDBC catalog comment.") - } + executeStatement(conn, options, dialect.removeSchemaCommentQuery(schema)) } /** - * Drops a namespace from the JDBC database. + * Drops a schema from the JDBC database. 
*/ - def dropNamespace( - conn: Connection, options: JDBCOptions, namespace: String, cascade: Boolean): Unit = { + def dropSchema( + conn: Connection, options: JDBCOptions, schema: String, cascade: Boolean): Unit = { val dialect = JdbcDialects.get(options.url) - executeStatement(conn, options, dialect.dropSchema(namespace, cascade)) + executeStatement(conn, options, dialect.dropSchema(schema, cascade)) } /** @@ -1154,11 +1158,17 @@ object JdbcUtils extends Logging { } } - def executeQuery(conn: Connection, options: JDBCOptions, sql: String): ResultSet = { + def executeQuery(conn: Connection, options: JDBCOptions, sql: String)( + f: ResultSet => Unit): Unit = { val statement = conn.createStatement try { statement.setQueryTimeout(options.queryTimeout) - statement.executeQuery(sql) + val rs = statement.executeQuery(sql) + try { + f(rs) + } finally { + rs.close() + } } finally { statement.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala index 58ad3f98d120b..f311cf63d1419 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala @@ -21,7 +21,6 @@ import java.util import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuilder import org.apache.spark.internal.Logging import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableChange} @@ -173,23 +172,14 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def namespaceExists(namespace: Array[String]): Boolean = namespace match { case Array(db) => JdbcUtils.withConnection(options) { conn => - val rs = conn.getMetaData.getSchemas(null, db) - while (rs.next()) { - if (rs.getString(1) == db) return true; - } - false + JdbcUtils.schemaExists(conn, options, db) } case _ => false } override def listNamespaces(): Array[Array[String]] = { JdbcUtils.withConnection(options) { conn => - val schemaBuilder = ArrayBuilder.make[Array[String]] - val rs = conn.getMetaData.getSchemas() - while (rs.next()) { - schemaBuilder += Array(rs.getString(1)) - } - schemaBuilder.result + JdbcUtils.listSchemas(conn, options) } } @@ -236,7 +226,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } JdbcUtils.withConnection(options) { conn => JdbcUtils.classifyException(s"Failed create name space: $db", dialect) { - JdbcUtils.createNamespace(conn, options, db, comment) + JdbcUtils.createSchema(conn, options, db, comment) } } @@ -254,7 +244,9 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging case set: NamespaceChange.SetProperty => if (set.property() == SupportsNamespaces.PROP_COMMENT) { JdbcUtils.withConnection(options) { conn => - JdbcUtils.createNamespaceComment(conn, options, db, set.value) + JdbcUtils.classifyException(s"Failed create comment on name space: $db", dialect) { + JdbcUtils.alterSchemaComment(conn, options, db, set.value) + } } } else { throw QueryCompilationErrors.cannotSetJDBCNamespaceWithPropertyError(set.property) @@ -263,7 +255,9 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging case unset: NamespaceChange.RemoveProperty => if (unset.property() == SupportsNamespaces.PROP_COMMENT) { 
JdbcUtils.withConnection(options) { conn => - JdbcUtils.removeNamespaceComment(conn, options, db) + JdbcUtils.classifyException(s"Failed remove comment on name space: $db", dialect) { + JdbcUtils.removeSchemaComment(conn, options, db) + } } } else { throw QueryCompilationErrors.cannotUnsetJDBCNamespaceWithPropertyError(unset.property) @@ -282,7 +276,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging case Array(db) if namespaceExists(namespace) => JdbcUtils.withConnection(options) { conn => JdbcUtils.classifyException(s"Failed drop name space: $db", dialect) { - JdbcUtils.dropNamespace(conn, options, db, cascade) + JdbcUtils.dropSchema(conn, options, db, cascade) true } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 7579be2ba20c8..23cdf25a86652 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, Date, Timestamp} +import java.sql.{Connection, Date, Statement, Timestamp} import java.time.{Instant, LocalDate} import java.util @@ -229,6 +229,45 @@ abstract class JdbcDialect extends Serializable with Logging{ } } + /** + * Create schema with an optional comment. Empty string means no comment. + */ + def createSchema(statement: Statement, schema: String, comment: String): Unit = { + val schemaCommentQuery = if (comment.nonEmpty) { + // We generate comment query here so that it can fail earlier without creating the schema. + getSchemaCommentQuery(schema, comment) + } else { + comment + } + statement.executeUpdate(s"CREATE SCHEMA ${quoteIdentifier(schema)}") + if (comment.nonEmpty) { + statement.executeUpdate(schemaCommentQuery) + } + } + + /** + * Check schema exists or not. + */ + def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = { + val rs = conn.getMetaData.getSchemas(null, schema) + while (rs.next()) { + if (rs.getString(1) == schema) return true; + } + false + } + + /** + * Lists all the schemas in this table. + */ + def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = { + val schemaBuilder = ArrayBuilder.make[Array[String]] + val rs = conn.getMetaData.getSchemas() + while (rs.next()) { + schemaBuilder += Array(rs.getString(1)) + } + schemaBuilder.result + } + /** * Return Some[true] iff `TRUNCATE TABLE` causes cascading default. * Some[true] : TRUNCATE TABLE causes cascading. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index 8316b3c04e107..c32499b5f32e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -21,6 +21,8 @@ import java.sql.{Connection, SQLException, Types} import java.util import java.util.Locale +import scala.collection.mutable.ArrayBuilder + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} @@ -76,6 +78,25 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { s"`$colName`" } + override def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = { + listSchemas(conn, options).exists(_.head == schema) + } + + override def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = { + val schemaBuilder = ArrayBuilder.make[Array[String]] + try { + JdbcUtils.executeQuery(conn, options, "SHOW SCHEMAS") { rs => + while (rs.next()) { + schemaBuilder += Array(rs.getString("Database")) + } + } + } catch { + case _: Exception => + logWarning("Cannot show schemas.") + } + schemaBuilder.result + } + override def getTableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } @@ -134,6 +155,14 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { case _ => JdbcUtils.getCommonJDBCType(dt) } + override def getSchemaCommentQuery(schema: String, comment: String): String = { + throw QueryExecutionErrors.unsupportedCreateNamespaceCommentError() + } + + override def removeSchemaCommentQuery(schema: String): String = { + throw QueryExecutionErrors.unsupportedRemoveNamespaceCommentError() + } + // CREATE INDEX syntax // https://dev.mysql.com/doc/refman/8.0/en/create-index.html override def createIndex( @@ -175,26 +204,27 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { val sql = s"SHOW INDEXES FROM $tableName" var indexMap: Map[String, TableIndex] = Map() try { - val rs = JdbcUtils.executeQuery(conn, options, sql) - while (rs.next()) { - val indexName = rs.getString("key_name") - val colName = rs.getString("column_name") - val indexType = rs.getString("index_type") - val indexComment = rs.getString("Index_comment") - if (indexMap.contains(indexName)) { - val index = indexMap.get(indexName).get - val newIndex = new TableIndex(indexName, indexType, - index.columns() :+ FieldReference(colName), - index.columnProperties, index.properties) - indexMap += (indexName -> newIndex) - } else { - // The only property we are building here is `COMMENT` because it's the only one - // we can get from `SHOW INDEXES`. 
- val properties = new util.Properties(); - if (indexComment.nonEmpty) properties.put("COMMENT", indexComment) - val index = new TableIndex(indexName, indexType, Array(FieldReference(colName)), - new util.HashMap[NamedReference, util.Properties](), properties) - indexMap += (indexName -> index) + JdbcUtils.executeQuery(conn, options, sql) { rs => + while (rs.next()) { + val indexName = rs.getString("key_name") + val colName = rs.getString("column_name") + val indexType = rs.getString("index_type") + val indexComment = rs.getString("Index_comment") + if (indexMap.contains(indexName)) { + val index = indexMap.get(indexName).get + val newIndex = new TableIndex(indexName, indexType, + index.columns() :+ FieldReference(colName), + index.columnProperties, index.properties) + indexMap += (indexName -> newIndex) + } else { + // The only property we are building here is `COMMENT` because it's the only one + // we can get from `SHOW INDEXES`. + val properties = new util.Properties(); + if (indexComment.nonEmpty) properties.put("COMMENT", indexComment) + val index = new TableIndex(indexName, indexType, Array(FieldReference(colName)), + new util.HashMap[NamedReference, util.Properties](), properties) + indexMap += (indexName -> index) + } } } } catch { @@ -219,4 +249,12 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { case _ => super.classifyException(message, e) } } + + override def dropSchema(schema: String, cascade: Boolean): String = { + if (cascade) { + s"DROP SCHEMA ${quoteIdentifier(schema)}" + } else { + throw QueryExecutionErrors.unsupportedDropNamespaceRestrictError() + } + } } From b0e5d0e1631db21fcf1ff3fa7fb32933ab019be6 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 2 Sep 2021 19:11:43 -0700 Subject: [PATCH 25/53] [SPARK-36351][SQL] Refactor filter push down in file source v2 ### What changes were proposed in this pull request? Currently in `V2ScanRelationPushDown`, we push the filters (partition filters + data filters) to file source, and then pass all the filters (partition filters + data filters) as post scan filters to v2 Scan, and later in `PruneFileSourcePartitions`, we separate partition filters and data filters, set them in the format of `Expression` to file source. Changes in this PR: When we push filters to file sources in `V2ScanRelationPushDown`, since we already have the information about partition column , we want to separate partition filter and data filter there. The benefit of doing this: - we can handle all the filter related work for v2 file source at one place instead of two (`V2ScanRelationPushDown` and `PruneFileSourcePartitions`), so the code will be cleaner and easier to maintain. - we actually have to separate partition filters and data filters at `V2ScanRelationPushDown`, otherwise, there is no way to find out which filters are partition filters, and we can't push down aggregate for parquet even if we only have partition filter. - By separating the filters early at `V2ScanRelationPushDown`, we only needs to check data filters to find out which one needs to be converted to data source filters (e.g. Parquet predicates, ORC predicates) and pushed down to file source, right now we are checking all the filters (both partition filters and data filters) - Similarly, we can only pass data filters as post scan filters to v2 Scan, because partition filters are used for partition pruning only, no need to pass them as post scan filters. In order to do this, we will have the following changes - add `pushFilters` in file source v2. 
In this method: - push both Expression partition filter and Expression data filter to file source. Have to use Expression filters because we need these for partition pruning. - data filters are used for filter push down. If file source needs to push down data filters, it translates the data filters from `Expression` to `Sources.Filer`, and then decides which filters to push down. - partition filters are used for partition pruning. - file source v2 no need to implement `SupportsPushdownFilters` any more, because when we separating the two types of filters, we have already set them on file data sources. It's redundant to use `SupportsPushdownFilters` to set the filters again on file data sources. ### Why are the changes needed? see section one ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests Closes #33650 from huaxingao/partition_filter. Authored-by: Huaxin Gao Signed-off-by: Liang-Chi Hsieh --- .../apache/spark/sql/v2/avro/AvroScan.scala | 4 -- .../spark/sql/v2/avro/AvroScanBuilder.scala | 19 +++---- .../SupportsPushDownCatalystFilters.scala | 41 ++++++++++++++ .../datasources/DataSourceUtils.scala | 21 ++++++- .../PruneFileSourcePartitions.scala | 56 ++----------------- .../execution/datasources/v2/FileScan.scala | 6 -- .../datasources/v2/FileScanBuilder.scala | 44 +++++++++++++-- .../datasources/v2/PushDownUtils.scala | 4 ++ .../datasources/v2/csv/CSVScan.scala | 6 +- .../datasources/v2/csv/CSVScanBuilder.scala | 19 +++---- .../datasources/v2/json/JsonScan.scala | 6 +- .../datasources/v2/json/JsonScanBuilder.scala | 19 +++---- .../datasources/v2/orc/OrcScan.scala | 4 -- .../datasources/v2/orc/OrcScanBuilder.scala | 19 +++---- .../datasources/v2/parquet/ParquetScan.scala | 4 -- .../v2/parquet/ParquetScanBuilder.scala | 15 ++--- .../datasources/v2/text/TextScan.scala | 6 +- .../datasources/v2/text/TextScanBuilder.scala | 3 +- .../datasources/json/JsonSuite.scala | 6 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 18 ++++++ 20 files changed, 176 insertions(+), 144 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index 144e9ad129feb..d0f38c12427c3 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -62,10 +62,6 @@ case class AvroScan( pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options && equivalentFilters(pushedFilters, a.pushedFilters) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala index 9420608bb22ce..8fae89a945826 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.v2.avro import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import 
org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -31,7 +31,7 @@ class AvroScanBuilder ( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { AvroScan( @@ -41,17 +41,16 @@ class AvroScanBuilder ( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.avroFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala new file mode 100644 index 0000000000000..9c2a4ac78a24a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal.connector + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.sources.Filter + +/** + * A mix-in interface for {@link FileScanBuilder}. File sources can implement this interface to + * push down filters to the file source. The pushed down filters will be separated into partition + * filters and data filters. Partition filters are used for partition pruning and data filters are + * used to reduce the size of the data to be read. + */ +trait SupportsPushDownCatalystFilters { + + /** + * Pushes down catalyst Expression filters (which will be separated into partition filters and + * data filters), and returns data filters that need to be evaluated after scanning. + */ + def pushFilters(filters: Seq[Expression]): Seq[Expression] + + /** + * Returns the data filters that are pushed to the data source via + * {@link #pushFilters(Expression[])}. 
+ */ + def pushedFilters: Array[Filter] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index fcd95a27bf8ca..67d03998a2a24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -28,6 +28,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.SparkUpgradeException import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_LEGACY_INT96, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util.RebaseDateTime import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions @@ -39,7 +40,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils -object DataSourceUtils { +object DataSourceUtils extends PredicateHelper { /** * The key to use for storing partitionBy columns as options. */ @@ -242,4 +243,22 @@ object DataSourceUtils { options } } + + def getPartitionFiltersAndDataFilters( + partitionSchema: StructType, + normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val partitionColumns = normalizedFilters.flatMap { expr => + expr.collect { + case attr: AttributeReference if partitionSchema.names.contains(attr.name) => + attr + } + } + val partitionSet = AttributeSet(partitionColumns) + val (partitionFilters, dataFilters) = normalizedFilters.partition(f => + f.references.subsetOf(partitionSet) + ) + val extraPartitionFilter = + dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) + (ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 0927027bee0bc..2e8e5426d47be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -17,52 +17,24 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogStatistics import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan} -import org.apache.spark.sql.types.StructType /** * Prune the partitions of file source based table using partition filters. Currently, this rule - * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]] and [[DataSourceV2ScanRelation]] - * with [[FileScan]]. + * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]]. 
* * For [[HadoopFsRelation]], the location will be replaced by pruned file index, and corresponding * statistics will be updated. And the partition filters will be kept in the filters of returned * logical plan. - * - * For [[DataSourceV2ScanRelation]], both partition filters and data filters will be added to - * its underlying [[FileScan]]. And the partition filters will be removed in the filters of - * returned logical plan. */ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] with PredicateHelper { - private def getPartitionKeyFiltersAndDataFilters( - sparkSession: SparkSession, - relation: LeafNode, - partitionSchema: StructType, - filters: Seq[Expression], - output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = { - val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output) - val partitionColumns = - relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver) - val partitionSet = AttributeSet(partitionColumns) - val (partitionFilters, dataFilters) = normalizedFilters.partition(f => - f.references.subsetOf(partitionSet) - ) - val extraPartitionFilter = - dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) - - (ExpressionSet(partitionFilters ++ extraPartitionFilter), dataFilters) - } - private def rebuildPhysicalOperation( projects: Seq[NamedExpression], filters: Seq[Expression], @@ -91,12 +63,14 @@ private[sql] object PruneFileSourcePartitions _, _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => - val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( - fsRelation.sparkSession, logicalRelation, partitionSchema, filters, + val normalizedFilters = DataSourceStrategy.normalizeExprs( + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), logicalRelation.output) + val (partitionKeyFilters, _) = DataSourceUtils + .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters) if (partitionKeyFilters.nonEmpty) { - val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) + val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession) // Change table stats based on the sizeInBytes of pruned files @@ -117,23 +91,5 @@ private[sql] object PruneFileSourcePartitions } else { op } - - case op @ PhysicalOperation(projects, filters, - v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) - if filters.nonEmpty => - val (partitionKeyFilters, dataFilters) = - getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation, - scan.readPartitionSchema, filters, output) - // The dataFilters are pushed down only once - if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) { - val prunedV2Relation = - v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters)) - // The pushed down partition filters don't need to be reevaluated. 
- val afterScanFilters = - ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty) - rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation) - } else { - op - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 4506bd3d49b5b..0212cdf63fcf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -69,12 +69,6 @@ trait FileScan extends Scan */ def dataFilters: Seq[Expression] - /** - * Create a new `FileScan` instance from the current one - * with different `partitionFilters` and `dataFilters` - */ - def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan - /** * If a file with `path` is unsplittable, return the unsplittable reason, * otherwise return `None`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 97874e8f4932e..309f045201140 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -16,19 +16,30 @@ */ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.SparkSession +import scala.collection.mutable + +import org.apache.spark.sql.{sources, SparkSession} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType abstract class FileScanBuilder( sparkSession: SparkSession, fileIndex: PartitioningAwareFileIndex, - dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns { + dataSchema: StructType) + extends ScanBuilder + with SupportsPushDownRequiredColumns + with SupportsPushDownCatalystFilters { private val partitionSchema = fileIndex.partitionSchema private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis protected val supportsNestedSchemaPruning = false protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields) + protected var partitionFilters = Seq.empty[Expression] + protected var dataFilters = Seq.empty[Expression] + protected var pushedDataFilters = Array.empty[Filter] override def pruneColumns(requiredSchema: StructType): Unit = { // [SPARK-30107] While `requiredSchema` might have pruned nested columns, @@ -48,7 +59,7 @@ abstract class FileScanBuilder( StructType(fields) } - protected def readPartitionSchema(): StructType = { + def readPartitionSchema(): StructType = { val requiredNameSet = createRequiredNameSet() val fields = partitionSchema.fields.filter { field => val colName = PartitioningUtils.getColName(field, isCaseSensitive) @@ -57,6 +68,31 @@ abstract class FileScanBuilder( StructType(fields) } + override def 
pushFilters(filters: Seq[Expression]): Seq[Expression] = { + val (partitionFilters, dataFilters) = + DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filters) + this.partitionFilters = partitionFilters + this.dataFilters = dataFilters + val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter] + for (filterExpr <- dataFilters) { + val translated = DataSourceStrategy.translateFilter(filterExpr, true) + if (translated.nonEmpty) { + translatedFilters += translated.get + } + } + pushedDataFilters = pushDataFilters(translatedFilters.toArray) + dataFilters + } + + override def pushedFilters: Array[Filter] = pushedDataFilters + + /* + * Push down data filters to the file source, so the data filters can be evaluated there to + * reduce the size of the data to be read. By default, data filters are not pushed down. + * File source needs to implement this method to push down data filters. + */ + protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter] + private def createRequiredNameSet(): Set[String] = requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index b54917e49ed3d..2bffa761dd9e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -97,6 +97,10 @@ object PushDownUtils extends PredicateHelper { } (Right(r.pushedFilters), (untranslatableExprs ++ postScanFilters).toSeq) + case f: FileScanBuilder => + val postScanFilters = f.pushFilters(filters) + (Left(f.pushedFilters), postScanFilters) + case _ => (Left(Nil), filters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index 3f77b2147f9ca..cc3c146106670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.csv.CSVDataSource -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -84,10 +84,6 @@ case class CSVScan( dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options && equivalentFilters(pushedFilters, c.pushedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala index f7a79bf31948e..2b6edd4f357ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.csv import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -32,7 +32,7 @@ case class CSVScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { CSVScan( @@ -42,17 +42,16 @@ case class CSVScanBuilder( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.csvFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala index 29eb8bec9a589..9ab367136fc97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.json.JsonDataSource -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -83,10 +83,6 @@ case class JsonScan( dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options && equivalentFilters(pushedFilters, j.pushedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala 
index cf1204566ddbd..c581617a4b7e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.json import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -31,7 +31,7 @@ class JsonScanBuilder ( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { JsonScan( sparkSession, @@ -40,17 +40,16 @@ class JsonScanBuilder ( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.jsonFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 8fa7f8dc41ead..7619e3c503139 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -68,8 +68,4 @@ case class OrcScan( override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index dc59526bb316b..cfa396f5482f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2.orc import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder @@ -35,7 +35,7 @@ case class OrcScanBuilder( schema: StructType, dataSchema: StructType, options: 
CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -45,20 +45,17 @@ case class OrcScanBuilder( override protected val supportsNestedSchemaPruning: Boolean = true override def build(): Scan = { - OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, - readDataSchema(), readPartitionSchema(), options, pushedFilters()) + OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), + readPartitionSchema(), options, pushedDataFilters, partitionFilters, dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { val dataTypeMap = OrcFilters.getSearchableTypeMap( readDataSchema(), SQLConf.get.caseSensitiveAnalysis) - _pushedFilters = OrcFilters.convertibleFilters(dataTypeMap, filters).toArray + OrcFilters.convertibleFilters(dataTypeMap, dataFilters).toArray + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 60573ba10ccb6..e277e334845c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -105,8 +105,4 @@ case class ParquetScan( override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 4b3f4e7edca6c..ff5137e928db3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder @@ -35,7 +35,7 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { lazy val hadoopConf = { val caseSensitiveMap = 
options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -63,17 +63,12 @@ case class ParquetScanBuilder( // The rebase mode doesn't matter here because the filters are used to determine // whether they is convertible. LegacyBehaviorPolicy.CORRECTED) - parquetFilters.convertibleFilters(this.filters).toArray + parquetFilters.convertibleFilters(pushedDataFilters).toArray } override protected val supportsNestedSchemaPruning: Boolean = true - private var filters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - this.filters - } + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters // Note: for Parquet, the actual filter push down happens in [[ParquetPartitionReaderFactory]]. // It requires the Parquet physical schema to determine whether a filter is convertible. @@ -82,6 +77,6 @@ case class ParquetScanBuilder( override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options) + readPartitionSchema(), pushedParquetFilters, options, partitionFilters, dataFilters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala index e75de2c4a4079..3582978a8c569 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.text.TextOptions -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -71,10 +71,6 @@ case class TextScan( readPartitionSchema, textOptions) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case t: TextScan => super.equals(t) && options == t.options diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala index b2b518c12b01a..0ebb098bfc1df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala @@ -33,6 +33,7 @@ case class TextScanBuilder( extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { - TextScan(sparkSession, fileIndex, readDataSchema(), readPartitionSchema(), options) + TextScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options, + partitionFilters, dataFilters) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 58921485b207d..e71f3b8c35e25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2965,16 +2965,14 @@ class JsonV2Suite extends JsonSuite { withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { file => val scanBuilder = getBuilder(file.getCanonicalPath) - assert(scanBuilder.pushFilters(filters) === filters) - assert(scanBuilder.pushedFilters() === filters) + assert(scanBuilder.pushDataFilters(filters) === filters) } } withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "false") { withTempPath { file => val scanBuilder = getBuilder(file.getCanonicalPath) - assert(scanBuilder.pushFilters(filters) === filters) - assert(scanBuilder.pushedFilters() === Array.empty[sources.Filter]) + assert(scanBuilder.pushDataFilters(filters) === Array.empty[sources.Filter]) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 637e01c260c99..a65d689385e63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -911,4 +911,22 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Row("david", 10000.00, 10000.000000, 1), Row("jen", 12000.00, 12000.000000, 1))) } + + test("scan with aggregate push-down: aggregate with partially pushed down filters" + + "will NOT push down") { + val df = spark.table("h2.test.employee") + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val query = df.select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter("SALARY > 100") + .filter(name($"shortName")) + .agg(sum($"SALARY").as("sum_salary")) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: []" + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(29000.0))) + } } From ff1a45751ae8b6833304df53f149a50d154e800a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 10 Oct 2021 22:20:09 -0700 Subject: [PATCH 26/53] [SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet ### What changes were proposed in this pull request? Push down Min/Max/Count to Parquet with the following restrictions: - nested types such as Array, Map or Struct will not be pushed down - Timestamp not pushed down because INT96 sort order is undefined, Parquet doesn't return statistics for INT96 - If the aggregate column is on partition column, only Count will be pushed, Min or Max will not be pushed down because Parquet doesn't return max/min for partition column. - If somehow the file doesn't have stats for the aggregate columns, Spark will throw Exception. - Currently, if filter/GROUP BY is involved, Min/Max/Count will not be pushed down, but the restriction will be lifted if the filter or GROUP BY is on partition column (https://issues.apache.org/jira/browse/SPARK-36646 and https://issues.apache.org/jira/browse/SPARK-36647) ### Why are the changes needed? Since parquet has the statistics information for min, max and count, we want to take advantage of this info and push down Min/Max/Count to parquet layer for better performance. 
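For example, here is a minimal sketch of the intended usage. It is illustrative only and not code from this patch: the local path, the session setup, and the `useV1SourceList` setting (needed so the DSV2 read path is exercised) are assumptions made for the example; only the new conf and the `PushedAggregation` explain keyword come from this PR.

```scala
// Illustrative sketch, not part of this patch. Paths and session settings are assumptions.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.parquet.aggregatePushdown", "true") // conf added by this patch
  .config("spark.sql.sources.useV1SourceList", "")       // read Parquet through the DSV2 path
  .getOrCreate()

// Write a small Parquet table to a hypothetical temp location.
spark.range(100).selectExpr("id", "id % 3 AS p")
  .write.mode("overwrite").parquet("/tmp/parquet_agg_pushdown_demo")

// MIN/MAX/COUNT with no filter and no GROUP BY are eligible for footer-based push down.
val agg = spark.read.parquet("/tmp/parquet_agg_pushdown_demo")
  .agg("id" -> "min", "id" -> "max", "id" -> "count")

// The scan node in the plan should report the pushed aggregates, e.g. something like
// "PushedAggregation: [MIN(id), MAX(id), COUNT(id)]" (exact text per the new test suite).
agg.explain()
```

Queries that involve a filter or GROUP BY keep the regular scan path for now; lifting that restriction for partition columns is tracked in the two JIRAs above.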
### Does this PR introduce _any_ user-facing change? Yes, `SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED` was added. If set to true, we will push down Min/Max/Count to Parquet. ### How was this patch tested? new test suites Closes #33639 from huaxingao/parquet_agg. Authored-by: Huaxin Gao Signed-off-by: Liang-Chi Hsieh --- .../apache/spark/sql/internal/SQLConf.scala | 10 + .../apache/spark/sql/types/StructType.scala | 2 +- .../datasources/parquet/ParquetUtils.scala | 227 ++++++++ .../datasources/v2/FileScanBuilder.scala | 2 +- .../ParquetPartitionReaderFactory.scala | 123 ++++- .../datasources/v2/parquet/ParquetScan.scala | 37 +- .../v2/parquet/ParquetScanBuilder.scala | 96 +++- .../org/apache/spark/sql/FileScanSuite.scala | 2 +- .../ParquetAggregatePushDownSuite.scala | 518 ++++++++++++++++++ 9 files changed, 984 insertions(+), 33 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 15927a9ffdfbf..cc63aeeb5e3bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -851,6 +851,14 @@ object SQLConf { .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") .createWithDefault(10) + val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") + .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + + " down to Parquet for optimization. MAX/MIN/COUNT for complex types and timestamp" + + " can't be pushed down") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("If true, data will be written in a way of Spark 1.4 and earlier. 
For example, decimal " + "values will be written in Apache Parquet's fixed-length byte array format, which other " + @@ -3679,6 +3687,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownInFilterThreshold: Int = getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index c9862cb629cff..50b197fb9aea3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def names: Array[String] = fieldNames private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap override def equals(that: Any): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index b91d75c55c513..1093f9c5aa51b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -16,11 +16,28 @@ */ package org.apache.spark.sql.execution.datasources.parquet +import java.util + +import scala.collection.mutable +import scala.language.existentials + import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.ParquetFileWriter +import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{PrimitiveType, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} object ParquetUtils { def inferSchema( @@ -127,4 +144,214 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to + * createRowBaseReader to read data from Parquet and aggregate at Spark layer. 
Instead we want + * to get the partial aggregates (Max/Min/Count) result using the statistics information + * from Parquet footer file, and then construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + private[sql] def createAggInternalRowFromFooter( + footer: ParquetMetadata, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + isCaseSensitive: Boolean): InternalRow = { + val (primitiveTypes, values) = getPushedDownAggResult( + footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive) + + val builder = Types.buildMessage + primitiveTypes.foreach(t => builder.addField(t)) + val parquetSchema = builder.named("root") + + val schemaConverter = new ParquetToSparkSchemaConverter + val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema, + None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) + val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName) + primitiveTypeNames.zipWithIndex.foreach { + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + val v = values(i).asInstanceOf[Boolean] + converter.getConverter(i).asPrimitiveConverter.addBoolean(v) + case (PrimitiveType.PrimitiveTypeName.INT32, i) => + val v = values(i).asInstanceOf[Integer] + converter.getConverter(i).asPrimitiveConverter.addInt(v) + case (PrimitiveType.PrimitiveTypeName.INT64, i) => + val v = values(i).asInstanceOf[Long] + converter.getConverter(i).asPrimitiveConverter.addLong(v) + case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => + val v = values(i).asInstanceOf[Float] + converter.getConverter(i).asPrimitiveConverter.addFloat(v) + case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => + val v = values(i).asInstanceOf[Double] + converter.getConverter(i).asPrimitiveConverter.addDouble(v) + case (PrimitiveType.PrimitiveTypeName.BINARY, i) => + val v = values(i).asInstanceOf[Binary] + converter.getConverter(i).asPrimitiveConverter.addBinary(v) + case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => + val v = values(i).asInstanceOf[Binary] + converter.getConverter(i).asPrimitiveConverter.addBinary(v) + case (_, i) => + throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i)) + } + converter.currentRecord + } + + /** + * When the aggregates (Max/Min/Count) are pushed down to Parquet, in the case of + * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader + * to read data from Parquet and aggregate at Spark layer. Instead we want + * to get the aggregates (Max/Min/Count) result using the statistics information + * from Parquet footer file, and then construct a ColumnarBatch from these aggregate results. 
+ * + * @return Aggregate results in the format of ColumnarBatch + */ + private[sql] def createAggColumnarBatchFromFooter( + footer: ParquetMetadata, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + offHeap: Boolean, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + isCaseSensitive: Boolean): ColumnarBatch = { + val row = createAggInternalRowFromFooter( + footer, + filePath, + dataSchema, + partitionSchema, + aggregation, + aggSchema, + datetimeRebaseMode, + isCaseSensitive) + val converter = new RowToColumnConverter(aggSchema) + val columnVectors = if (offHeap) { + OffHeapColumnVector.allocateColumns(1, aggSchema) + } else { + OnHeapColumnVector.allocateColumns(1, aggSchema) + } + converter.convert(row, columnVectors.toArray) + new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) + } + + /** + * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics + * information from Parquet footer file. + * + * @return A tuple of `Array[PrimitiveType]` and Array[Any]. + * The first element is the Parquet PrimitiveType of the aggregate column, + * and the second element is the aggregated value. + */ + private[sql] def getPushedDownAggResult( + footer: ParquetMetadata, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + isCaseSensitive: Boolean) + : (Array[PrimitiveType], Array[Any]) = { + val footerFileMetaData = footer.getFileMetaData + val fields = footerFileMetaData.getSchema.getFields + val blocks = footer.getBlocks + val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType] + val valuesBuilder = mutable.ArrayBuilder.make[Any] + + assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down") + aggregation.aggregateExpressions.foreach { agg => + var value: Any = None + var rowCount = 0L + var isCount = false + var index = 0 + var schemaName = "" + blocks.forEach { block => + val blockMetaData = block.getColumns + agg match { + case max: Max => + val colName = max.column.fieldNames.head + index = dataSchema.fieldNames.toList.indexOf(colName) + schemaName = "max(" + colName + ")" + val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true) + if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { + value = currentMax + } + case min: Min => + val colName = min.column.fieldNames.head + index = dataSchema.fieldNames.toList.indexOf(colName) + schemaName = "min(" + colName + ")" + val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false) + if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { + value = currentMin + } + case count: Count => + schemaName = "count(" + count.column.fieldNames.head + ")" + rowCount += block.getRowCount + var isPartitionCol = false + if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) + .toSet.contains(count.column.fieldNames.head)) { + isPartitionCol = true + } + isCount = true + if (!isPartitionCol) { + index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head) + // Count(*) includes the null values, but Count(colName) doesn't. 
+ rowCount -= getNumNulls(filePath, blockMetaData, index) + } + case _: CountStar => + schemaName = "count(*)" + rowCount += block.getRowCount + isCount = true + case _ => + } + } + if (isCount) { + valuesBuilder += rowCount + primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName); + } else { + valuesBuilder += value + val field = fields.get(index) + primitiveTypeBuilder += Types.required(field.asPrimitiveType.getPrimitiveTypeName) + .as(field.getLogicalTypeAnnotation) + .length(field.asPrimitiveType.getTypeLength) + .named(schemaName) + } + } + (primitiveTypeBuilder.result, valuesBuilder.result) + } + + /** + * Get the Max or Min value for ith column in the current block + * + * @return the Max or Min value + */ + private def getCurrentBlockMaxOrMin( + filePath: String, + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int, + isMax: Boolean): Any = { + val statistics = columnChunkMetaData.get(i).getStatistics + if (!statistics.hasNonNullValue) { + throw new UnsupportedOperationException(s"No min/max found for Parquet file $filePath. " + + s"Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") + } else { + if (isMax) statistics.genericGetMax else statistics.genericGetMin + } + } + + private def getNumNulls( + filePath: String, + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int): Long = { + val statistics = columnChunkMetaData.get(i).getStatistics + if (!statistics.isNumNullsSet) { + throw new UnsupportedOperationException(s"Number of nulls not set for Parquet file" + + s" $filePath. Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute" + + s" again") + } + statistics.getNumNulls; + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 309f045201140..2dc4137d6f9a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -96,6 +96,6 @@ abstract class FileScanBuilder( private def createRequiredNameSet(): Set[String] = requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet - private val partitionNameSet: Set[String] = + val partitionNameSet: Set[String] = partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 058669b0937fa..111018b579ed2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -25,14 +25,16 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS +import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS} import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} +import 
org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ @@ -53,6 +55,7 @@ import org.apache.spark.util.SerializableConfiguration * @param readDataSchema Required schema of Parquet files. * @param partitionSchema Schema of partitions. * @param filters Filters to be pushed down in the batch scan. + * @param aggregation Aggregation to be pushed down in the batch scan. * @param parquetOptions The options of Parquet datasource that are set for the read. */ case class ParquetPartitionReaderFactory( @@ -62,6 +65,7 @@ case class ParquetPartitionReaderFactory( readDataSchema: StructType, partitionSchema: StructType, filters: Array[Filter], + aggregation: Option[Aggregation], parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging { private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields) @@ -80,6 +84,30 @@ case class ParquetPartitionReaderFactory( private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + private def getFooter(file: PartitionedFile): ParquetMetadata = { + val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + + if (aggregation.isEmpty) { + ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS) + } else { + // For aggregate push down, we will get max/min/count from footer statistics. + // We want to read the footer for the whole file instead of reading multiple + // footers for every split of the file. Basically if the start (the beginning of) + // the offset in PartitionedFile is 0, we will read the footer. Otherwise, it means + // that we have already read footer for that file, so we will skip reading again. 
+ if (file.start != 0) return null + ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) + } + } + + private def getDatetimeRebaseMode( + footerFileMetaData: FileMetaData): LegacyBehaviorPolicy.Value = { + DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + } + override def supportColumnarReads(partition: InputPartition): Boolean = { sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled && resultSchema.length <= sqlConf.wholeStageMaxNumFields && @@ -87,18 +115,44 @@ case class ParquetPartitionReaderFactory( } override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { - val reader = if (enableVectorizedReader) { - createVectorizedReader(file) - } else { - createRowBaseReader(file) - } + val fileReader = if (aggregation.isEmpty) { + val reader = if (enableVectorizedReader) { + createVectorizedReader(file) + } else { + createRowBaseReader(file) + } + + new PartitionReader[InternalRow] { + override def next(): Boolean = reader.nextKeyValue() - val fileReader = new PartitionReader[InternalRow] { - override def next(): Boolean = reader.nextKeyValue() + override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] - override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] + override def close(): Unit = reader.close() + } + } else { + new PartitionReader[InternalRow] { + private var hasNext = true + private lazy val row: InternalRow = { + val footer = getFooter(file) + if (footer != null && footer.getBlocks.size > 0) { + ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema, + partitionSchema, aggregation.get, readDataSchema, + getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + } else { + null + } + } + override def next(): Boolean = { + hasNext && row != null + } - override def close(): Unit = reader.close() + override def get(): InternalRow = { + hasNext = false + row + } + + override def close(): Unit = {} + } } new PartitionReaderWithPartitionValues(fileReader, readDataSchema, @@ -106,17 +160,45 @@ case class ParquetPartitionReaderFactory( } override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val vectorizedReader = createVectorizedReader(file) - vectorizedReader.enableReturningBatches() + val fileReader = if (aggregation.isEmpty) { + val vectorizedReader = createVectorizedReader(file) + vectorizedReader.enableReturningBatches() + + new PartitionReader[ColumnarBatch] { + override def next(): Boolean = vectorizedReader.nextKeyValue() - new PartitionReader[ColumnarBatch] { - override def next(): Boolean = vectorizedReader.nextKeyValue() + override def get(): ColumnarBatch = + vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] - override def get(): ColumnarBatch = - vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] + override def close(): Unit = vectorizedReader.close() + } + } else { + new PartitionReader[ColumnarBatch] { + private var hasNext = true + private val row: ColumnarBatch = { + val footer = getFooter(file) + if (footer != null && footer.getBlocks.size > 0) { + ParquetUtils.createAggColumnarBatchFromFooter(footer, file.filePath, dataSchema, + partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector, + getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + } else { + null + } + } + + override def next(): Boolean = { + hasNext && row != null + } + + override def get(): ColumnarBatch = { + hasNext = false + row + 
} - override def close(): Unit = vectorizedReader.close() + override def close(): Unit = {} + } } + fileReader } private def buildReaderBase[T]( @@ -131,11 +213,8 @@ case class ParquetPartitionReaderFactory( val filePath = new Path(new URI(file.filePath)) val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) - lazy val footerFileMetaData = - ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData - val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( - footerFileMetaData.getKeyValueMetaData.get, - datetimeRebaseModeInRead) + lazy val footerFileMetaData = getFooter(file).getFileMetaData + val datetimeRebaseMode = getDatetimeRebaseMode(footerFileMetaData) // Try to push down filters when filter push-down is enabled. val pushed = if (enableParquetFilterPushDown) { val parquetSchema = footerFileMetaData.getSchema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index e277e334845c9..42dc287f73129 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -24,6 +24,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} @@ -43,10 +44,17 @@ case class ParquetScan( readPartitionSchema: StructType, pushedFilters: Array[Filter], options: CaseInsensitiveStringMap, + pushedAggregate: Option[Aggregation] = None, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true + override def readSchema(): StructType = { + // If aggregate is pushed down, schema has already been pruned in `ParquetScanBuilder` + // and no need to call super.readSchema() + if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema() + } + override def createReaderFactory(): PartitionReaderFactory = { val readDataSchemaAsJson = readDataSchema.json hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) @@ -86,23 +94,46 @@ case class ParquetScan( readDataSchema, readPartitionSchema, pushedFilters, + pushedAggregate, new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)) } override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => + val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) { + equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) + } else { + pushedAggregate.isEmpty && p.pushedAggregate.isEmpty + } super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) + equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { + (seqToString(pushedAggregate.get.aggregateExpressions), 
+ seqToString(pushedAggregate.get.groupByColumns)) + } else { + ("[]", "[]") + } + override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) + } + + private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { + a.aggregateExpressions.sortBy(_.hashCode()) + .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && + a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index ff5137e928db3..01523250991cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -35,7 +37,8 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownAggregates{ lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -66,6 +69,10 @@ case class ParquetScanBuilder( parquetFilters.convertibleFilters(pushedDataFilters).toArray } + private var finalSchema = new StructType() + + private var pushedAggregations = Option.empty[Aggregation] + override protected val supportsNestedSchemaPruning: Boolean = true override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters @@ -75,8 +82,87 @@ case class ParquetScanBuilder( // All filters that can be converted to Parquet are pushed down. 
override def pushedFilters(): Array[Filter] = pushedParquetFilters + override def pushAggregation(aggregation: Aggregation): Boolean = { + + def getStructFieldForCol(col: NamedReference): StructField = { + schema.nameToField(col.fieldNames.head) + } + + def isPartitionCol(col: NamedReference) = { + partitionNameSet.contains(col.fieldNames.head) + } + + def processMinOrMax(agg: AggregateFunc): Boolean = { + val (column, aggType) = agg match { + case max: Max => (max.column, "max") + case min: Min => (min.column, "min") + case _ => + throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") + } + + if (isPartitionCol(column)) { + // don't push down partition column, footer doesn't have max/min for partition column + return false + } + val structField = getStructFieldForCol(column) + + structField.dataType match { + // not push down complex type + // not push down Timestamp because INT96 sort order is undefined, + // Parquet doesn't return statistics for INT96 + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => + false + case _ => + finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) + true + } + } + + if (!sparkSession.sessionState.conf.parquetAggregatePushDown || + aggregation.groupByColumns.nonEmpty || dataFilters.length > 0) { + // Parquet footer has max/min/count for columns + // e.g. SELECT COUNT(col1) FROM t + // but footer doesn't have max/min/count for a column if max/min/count + // are combined with filter or group by + // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8 + // SELECT COUNT(col1) FROM t GROUP BY col2 + // Todo: 1. add support if groupby column is partition col + // (https://issues.apache.org/jira/browse/SPARK-36646) + // 2. add support if filter col is partition col + // (https://issues.apache.org/jira/browse/SPARK-36647) + return false + } + + aggregation.groupByColumns.foreach { col => + if (col.fieldNames.length != 1) return false + finalSchema = finalSchema.add(getStructFieldForCol(col)) + } + + aggregation.aggregateExpressions.foreach { + case max: Max => + if (!processMinOrMax(max)) return false + case min: Min => + if (!processMinOrMax(min)) return false + case count: Count => + if (count.column.fieldNames.length != 1 || count.isDistinct) return false + finalSchema = + finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) + case _: CountStar => + finalSchema = finalSchema.add(StructField("count(*)", LongType)) + case _ => + return false + } + this.pushedAggregations = Some(aggregation) + true + } + override def build(): Scan = { - ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options, partitionFilters, dataFilters) + // the `finalSchema` is either pruned in pushAggregation (if aggregates are + // pushed down), or pruned in readDataSchema() (in regular column pruning). These + // two are mutual exclusive. 
+ if (pushedAggregations.isEmpty) finalSchema = readDataSchema() + ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, + readPartitionSchema(), pushedParquetFilters, options, pushedAggregations, + partitionFilters, dataFilters) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 4e7fe8455ff93..11934c2f08f4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -354,7 +354,7 @@ class FileScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("ParquetScan", (s, fi, ds, rds, rps, f, o, pf, df) => - ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df), + ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, None, pf, df), Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala new file mode 100644 index 0000000000000..c795bd9ff3389 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala @@ -0,0 +1,518 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.sql.{Date, Timestamp} + +import org.apache.spark.SparkConf +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.functions.min +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +/** + * A test suite that tests Max/Min/Count push down. 
+ */ +abstract class ParquetAggregatePushDownSuite + extends QueryTest + with ParquetTest + with SharedSparkSession + with ExplainSuiteHelper { + import testImplicits._ + + test("aggregate push down - nested column: Max(top level column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val max = sql("SELECT Max(_1) FROM t") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + } + } + } + + test("aggregate push down - nested column: Count(top level column) push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val count = sql("SELECT Count(_1) FROM t") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [COUNT(_1)]" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + + test("aggregate push down - nested column: Max(nested column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val max = sql("SELECT Max(_1._2[0]) FROM t") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + } + } + } + + test("aggregate push down - nested column: Count(nested column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val count = sql("SELECT Count(_1._2[0]) FROM t") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + + test("aggregate push down - Max(partition Col): not push dow") { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val max = sql("SELECT Max(p) FROM tmp") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + checkAnswer(max, Seq(Row(2))) + } + } + } + } + + test("aggregate push down - Count(partition Col): push down") { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + val count = sql("SELECT 
COUNT(p) FROM tmp") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [COUNT(p)]" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + } + } + + test("aggregate push down - Filter alias over aggregate") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_1), MAX(_1)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(7))) + } + } + } + + test("aggregate push down - alias over aggregate") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_1)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-1, 0))) + } + } + } + + test("aggregate push down - aggregate over alias not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val df = spark.table("t") + val query = df.select($"_1".as("col1")).agg(min($"col1")) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" // aggregate alias not pushed down + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(-2))) + } + } + } + + test("aggregate push down - query with group by not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // aggregate not pushed down if there is group by + val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-2), Row(0), Row(2), Row(3))) + } + } + } + + test("aggregate push down - query with filter not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // aggregate not pushed down if there is filter + val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + 
checkAnswer(selectAgg, Seq(Row(2))) + } + } + } + + test("aggregate push down - push down only if all the aggregates can be pushed down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // not push down since sum can't be pushed down + val selectAgg = sql("SELECT min(_1), sum(_3) FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-2, 41))) + } + } + } + + test("aggregate push down - MIN/MAX/COUNT") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + + " count(*), count(_1), count(_2), count(_3) FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_3), " + + "MAX(_3), " + + "MIN(_1), " + + "MAX(_1), " + + "COUNT(*), " + + "COUNT(_1), " + + "COUNT(_2), " + + "COUNT(_3)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + + checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6))) + } + } + } + + test("aggregate push down - different data types") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } + + val rows = + Seq( + Row( + "a string", + true, + 10.toByte, + "Spark SQL".getBytes, + 12.toShort, + 3, + Long.MaxValue, + 0.15.toFloat, + 0.75D, + Decimal("12.345678"), + ("2021-01-01").date, + ("2015-01-01 23:50:59.123").ts), + Row( + "test string", + false, + 1.toByte, + "Parquet".getBytes, + 2.toShort, + null, + Long.MinValue, + 0.25.toFloat, + 0.85D, + Decimal("1.2345678"), + ("2015-01-01").date, + ("2021-01-01 23:50:59.123").ts), + Row( + null, + true, + 10000.toByte, + "Spark ML".getBytes, + 222.toShort, + 113, + 11111111L, + 0.25.toFloat, + 0.75D, + Decimal("12345.678"), + ("2004-06-19").date, + ("1999-08-26 10:43:59.123").ts) + ) + + val schema = StructType(List(StructField("StringCol", StringType, true), + StructField("BooleanCol", BooleanType, false), + StructField("ByteCol", ByteType, false), + StructField("BinaryCol", BinaryType, false), + StructField("ShortCol", ShortType, false), + StructField("IntegerCol", IntegerType, true), + StructField("LongCol", LongType, false), + StructField("FloatCol", FloatType, false), + StructField("DoubleCol", DoubleType, false), + StructField("DecimalCol", DecimalType(25, 5), true), + StructField("DateCol", DateType, false), + StructField("TimestampCol", TimestampType, false)).toArray) + + val rdd = sparkContext.parallelize(rows) + withTempPath { file => + spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath) + withTempView("test") { + spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test") + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + + val testMinWithTS 
= sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + testMinWithTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment) + } + + checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts))) + + val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test") + + testMinWithOutTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(StringCol), " + + "MIN(BooleanCol), " + + "MIN(ByteCol), " + + "MIN(BinaryCol), " + + "MIN(ShortCol), " + + "MIN(IntegerCol), " + + "MIN(LongCol), " + + "MIN(FloatCol), " + + "MIN(DoubleCol), " + + "MIN(DecimalCol), " + + "MIN(DateCol)]" + checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment) + } + + checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + ("2004-06-19").date))) + + val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + testMaxWithTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment) + } + + checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) + + val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test") + + testMaxWithoutTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(StringCol), " + + "MAX(BooleanCol), " + + "MAX(ByteCol), " + + "MAX(BinaryCol), " + + "MAX(ShortCol), " + + "MAX(IntegerCol), " + + "MAX(LongCol), " + + "MAX(FloatCol), " + + "MAX(DoubleCol), " + + "MAX(DecimalCol), " + + "MAX(DateCol)]" + checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment) + } + + checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date))) + + val testCount = sql("SELECT count(StringCol), count(BooleanCol)," + + " count(ByteCol), count(BinaryCol), count(ShortCol), 
count(IntegerCol)," + " count(LongCol), count(FloatCol), count(DoubleCol)," + " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test") + + testCount.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [" + + "COUNT(StringCol), " + + "COUNT(BooleanCol), " + + "COUNT(ByteCol), " + + "COUNT(BinaryCol), " + + "COUNT(ShortCol), " + + "COUNT(IntegerCol), " + + "COUNT(LongCol), " + + "COUNT(FloatCol), " + + "COUNT(DoubleCol), " + + "COUNT(DecimalCol), " + + "COUNT(DateCol), " + + "COUNT(TimestampCol)]" + checkKeywordsExistsInExplain(testCount, expected_plan_fragment) + } + + checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))) + } + } + } + } + } + + test("aggregate push down - column name case sensitivity") { + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(id), MIN(id)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(9, 0))) + } + } + } + } + } +} + +class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "parquet") +} + +class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "") +} From 762af8332bbbeca44df3ae9b650ff78a4d1d604a Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Thu, 28 Oct 2021 17:29:15 -0700 Subject: [PATCH 27/53] [SPARK-34960][SQL] Aggregate push down for ORC ### What changes were proposed in this pull request? This PR adds the aggregate push down feature for the ORC data source v2 reader. At a high level, the PR does: * The supported aggregate expressions are MIN/MAX/COUNT, same as [Parquet aggregate push down](https://github.com/apache/spark/pull/33639). * BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DateType are allowed in MIN/MAX aggregate push down. All other column types are not allowed in MIN/MAX aggregate push down. * All column types are supported in COUNT aggregate push down. * Nested column's sub-fields are disallowed in aggregate push down. * If the file does not have valid statistics, Spark will throw an exception and fail the query. * If the aggregate has a filter or group-by column, the aggregate will not be pushed down. At code level, the PR does: * `OrcScanBuilder`: `pushAggregation()` checks whether the aggregation can be pushed down. Most of the checking logic is shared between Parquet and ORC and extracted into `AggregatePushDownUtils.getSchemaForPushedAggregation()`. `OrcScanBuilder` will create an `OrcScan` with the aggregation and the aggregation data schema. * `OrcScan`: `createReaderFactory` creates an ORC reader factory with the aggregation and schema, similar to the change in `ParquetScan`. 
* `OrcPartitionReaderFactory`: `buildReaderWithAggregates` creates an ORC reader with aggregate push down (i.e. it reads the ORC file footer to process column statistics instead of reading the actual data in the file). `buildColumnarReaderWithAggregates` creates a columnar ORC reader similarly. Both delegate the real footer-reading work to `OrcUtils.createAggInternalRowFromFooter`. * `OrcUtils.createAggInternalRowFromFooter`: reads the ORC file footer to process column statistics (the real heavy lifting happens here). Similar to `ParquetUtils.createAggInternalRowFromFooter`. Leverages utility methods such as `OrcFooterReader.readStatistics`. * `OrcFooterReader`: `readStatistics` reads the ORC `ColumnStatistics[]` into Spark `OrcColumnStatistics`. The transformation is needed here because ORC `ColumnStatistics[]` stores all column statistics in a flattened array, which is hard to process, while Spark `OrcColumnStatistics` stores the statistics in a nested tree structure (like `StructType`). This is used by `OrcUtils.createAggInternalRowFromFooter`. * `OrcColumnStatistics`: an easy-to-manipulate structure for ORC `ColumnStatistics`. This is used by `OrcFooterReader.readStatistics`. ### Why are the changes needed? To improve the performance of queries with aggregates. ### Does this PR introduce _any_ user-facing change? Yes. A user-facing config `spark.sql.orc.aggregatePushdown` is added to control enabling/disabling the aggregate push down for ORC. By default the feature is disabled. ### How was this patch tested? Added unit tests in `FileSourceAggregatePushDownSuite.scala`. Refactored all unit tests from https://github.com/apache/spark/pull/33639, which now cover both Parquet and ORC. Closes #34298 from c21/orc-agg. Authored-by: Cheng Su Signed-off-by: Liang-Chi Hsieh --- .../apache/spark/sql/internal/SQLConf.scala | 10 + .../apache/spark/sql/types/StructType.scala | 2 +- .../datasources/orc/OrcColumnStatistics.java | 80 ++++ .../datasources/orc/OrcFooterReader.java | 67 +++ .../datasources/AggregatePushDownUtils.scala | 141 +++++++ .../datasources/orc/OrcDeserializer.scala | 16 + .../execution/datasources/orc/OrcUtils.scala | 145 ++++++- .../datasources/parquet/ParquetUtils.scala | 41 -- .../v2/orc/OrcPartitionReaderFactory.scala | 93 ++++- .../datasources/v2/orc/OrcScan.scala | 45 +- .../datasources/v2/orc/OrcScanBuilder.scala | 43 +- .../ParquetPartitionReaderFactory.scala | 14 +- .../datasources/v2/parquet/ParquetScan.scala | 10 +- .../v2/parquet/ParquetScanBuilder.scala | 90 +--- .../org/apache/spark/sql/FileScanSuite.scala | 2 +- ...=> FileSourceAggregatePushDownSuite.scala} | 393 +++++++++++------- 16 files changed, 870 insertions(+), 322 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/{parquet/ParquetAggregatePushDownSuite.scala => FileSourceAggregatePushDownSuite.scala} (59%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cc63aeeb5e3bf..4518b9ddfc5ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -950,6
+950,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.aggregatePushdown") + .doc("If true, aggregates will be pushed down to ORC for optimization. Support MIN, MAX and " + + "COUNT as aggregate expression. For MIN/MAX, support boolean, integer, float and date " + + "type. For COUNT, support all data types.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val ORC_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.orc.mergeSchema") .doc("When true, the Orc data source merges schemas collected from all data files, " + "otherwise the schema is picked from a random data file.") @@ -3691,6 +3699,8 @@ class SQLConf extends Serializable with Logging { def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) + def orcAggregatePushDown: Boolean = getConf(ORC_AGGREGATE_PUSHDOWN_ENABLED) + def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 50b197fb9aea3..c9862cb629cff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def names: Array[String] = fieldNames private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap override def equals(that: Any): Boolean = { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java new file mode 100644 index 0000000000000..8adb9e8ca20be --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; + +import java.util.ArrayList; +import java.util.List; + +/** + * Columns statistics interface wrapping ORC {@link ColumnStatistics}s. 
+ * + * Because ORC {@link ColumnStatistics}s are stored as a flattened array in the ORC file footer, + * this class is used to convert ORC {@link ColumnStatistics}s from an array to a nested tree structure, + * according to data types. The flattened array stores all data types (including nested types) in + * tree pre-ordering. This is used for aggregate push down in ORC. + * + * For nested data types (array, map and struct), the sub-field statistics are stored recursively + * inside the parent column's children field. Here is an example of {@link OrcColumnStatistics}: + * + * Data schema: + * c1: int + * c2: struct<f1: int, f2: float> + * c3: map<int, string> + * c4: array<int> + * + * OrcColumnStatistics + * | (children) + * --------------------------------------------- + * / | \ \ + * c1 c2 c3 c4 + * (integer) (struct) (map) (array) + * (min:1, | (children) | (children) | (children) + * max:10) ----- ----- element + * / \ / \ (integer) + * c2.f1 c2.f2 key value + * (integer) (float) (integer) (string) + * (min:0.1, (min:"a", + * max:100.5) max:"zzz") + */ +public class OrcColumnStatistics { + private final ColumnStatistics statistics; + private final List<OrcColumnStatistics> children; + + public OrcColumnStatistics(ColumnStatistics statistics) { + this.statistics = statistics; + this.children = new ArrayList<>(); + } + + public ColumnStatistics getStatistics() { + return statistics; + } + + public OrcColumnStatistics get(int ordinal) { + if (ordinal < 0 || ordinal >= children.size()) { + throw new IndexOutOfBoundsException( + String.format("Ordinal %d out of bounds of statistics size %d", ordinal, children.size())); + } + return children.get(ordinal); + } + + public void add(OrcColumnStatistics newChild) { + children.add(newChild); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java new file mode 100644 index 0000000000000..546b048648844 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.spark.sql.types.*; + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.Queue; + +/** + * {@link OrcFooterReader} is a util class which encapsulates the helper + * methods of reading ORC file footer. + */ +public class OrcFooterReader { + + /** + * Read the columns statistics from ORC file footer. + * + * @param orcReader the reader to read ORC file footer. + * @return Statistics for all columns in the file.
+ */ + public static OrcColumnStatistics readStatistics(Reader orcReader) { + TypeDescription orcSchema = orcReader.getSchema(); + ColumnStatistics[] orcStatistics = orcReader.getStatistics(); + StructType sparkSchema = OrcUtils.toCatalystSchema(orcSchema); + return convertStatistics(sparkSchema, new LinkedList<>(Arrays.asList(orcStatistics))); + } + + /** + * Convert a queue of ORC {@link ColumnStatistics}s into Spark {@link OrcColumnStatistics}. + * The queue of ORC {@link ColumnStatistics}s are assumed to be ordered as tree pre-order. + */ + private static OrcColumnStatistics convertStatistics( + DataType sparkSchema, Queue orcStatistics) { + OrcColumnStatistics statistics = new OrcColumnStatistics(orcStatistics.remove()); + if (sparkSchema instanceof StructType) { + for (StructField field : ((StructType) sparkSchema).fields()) { + statistics.add(convertStatistics(field.dataType(), orcStatistics)); + } + } else if (sparkSchema instanceof MapType) { + statistics.add(convertStatistics(((MapType) sparkSchema).keyType(), orcStatistics)); + statistics.add(convertStatistics(((MapType) sparkSchema).valueType(), orcStatistics)); + } else if (sparkSchema instanceof ArrayType) { + statistics.add(convertStatistics(((ArrayType) sparkSchema).elementType(), orcStatistics)); + } + return statistics; + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala new file mode 100644 index 0000000000000..6340d97af1a04 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + +/** + * Utility class for aggregate push down to Parquet and ORC. + */ +object AggregatePushDownUtils { + + /** + * Get the data schema for aggregate to be pushed down. 
+ */ + def getSchemaForPushedAggregation( + aggregation: Aggregation, + schema: StructType, + partitionNames: Set[String], + dataFilters: Seq[Expression]): Option[StructType] = { + + var finalSchema = new StructType() + + def getStructFieldForCol(col: NamedReference): StructField = { + schema.apply(col.fieldNames.head) + } + + def isPartitionCol(col: NamedReference) = { + partitionNames.contains(col.fieldNames.head) + } + + def processMinOrMax(agg: AggregateFunc): Boolean = { + val (column, aggType) = agg match { + case max: Max => (max.column, "max") + case min: Min => (min.column, "min") + case _ => + throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") + } + + if (isPartitionCol(column)) { + // don't push down partition column, footer doesn't have max/min for partition column + return false + } + val structField = getStructFieldForCol(column) + + structField.dataType match { + // not push down complex type + // not push down Timestamp because INT96 sort order is undefined, + // Parquet doesn't return statistics for INT96 + // not push down Parquet Binary because min/max could be truncated + // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary + // could be Spark StringType, BinaryType or DecimalType. + // not push down for ORC with same reason. + case BooleanType | ByteType | ShortType | IntegerType + | LongType | FloatType | DoubleType | DateType => + finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) + true + case _ => + false + } + } + + if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) { + // Parquet/ORC footer has max/min/count for columns + // e.g. SELECT COUNT(col1) FROM t + // but footer doesn't have max/min/count for a column if max/min/count + // are combined with filter or group by + // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8 + // SELECT COUNT(col1) FROM t GROUP BY col2 + // However, if the filter is on partition column, max/min/count can still be pushed down + // Todo: add support if groupby column is partition col + // (https://issues.apache.org/jira/browse/SPARK-36646) + return None + } + + aggregation.aggregateExpressions.foreach { + case max: Max => + if (!processMinOrMax(max)) return None + case min: Min => + if (!processMinOrMax(min)) return None + case count: Count => + if (count.column.fieldNames.length != 1 || count.isDistinct) return None + finalSchema = + finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) + case _: CountStar => + finalSchema = finalSchema.add(StructField("count(*)", LongType)) + case _ => + return None + } + + Some(finalSchema) + } + + /** + * Check if two Aggregation `a` and `b` is equal or not. + */ + def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { + a.aggregateExpressions.sortBy(_.hashCode()) + .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && + a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) + } + + /** + * Convert the aggregates result from `InternalRow` to `ColumnarBatch`. + * This is used for columnar reader. 
+ */ + def convertAggregatesRowToBatch( + aggregatesAsRow: InternalRow, + aggregatesSchema: StructType, + offHeap: Boolean): ColumnarBatch = { + val converter = new RowToColumnConverter(aggregatesSchema) + val columnVectors = if (offHeap) { + OffHeapColumnVector.allocateColumns(1, aggregatesSchema) + } else { + OnHeapColumnVector.allocateColumns(1, aggregatesSchema) + } + converter.convert(aggregatesAsRow, columnVectors.toArray) + new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index fa8977f239164..59a52b318622b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -68,6 +68,22 @@ class OrcDeserializer( resultRow } + def deserializeFromValues(orcValues: Seq[WritableComparable[_]]): InternalRow = { + var targetColumnIndex = 0 + while (targetColumnIndex < fieldWriters.length) { + if (fieldWriters(targetColumnIndex) != null) { + val value = orcValues(requestedColIds(targetColumnIndex)) + if (value == null) { + resultRow.setNullAt(targetColumnIndex) + } else { + fieldWriters(targetColumnIndex)(value) + } + } + targetColumnIndex += 1 + } + resultRow + } + /** * Creates a writer to write ORC values to Catalyst data structure at the given ordinal. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index a8647726fe022..4cf13a90ce829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -24,15 +24,19 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription, Writer} +import org.apache.hadoop.hive.serde2.io.DateWritable +import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable} +import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer} -import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.SchemaMergeUtils import org.apache.spark.sql.types._ @@ -84,7 +88,7 @@ object OrcUtils extends Logging { } } - private def toCatalystSchema(schema: TypeDescription): StructType = { + def toCatalystSchema(schema: TypeDescription): 
StructType = { // The Spark query engine has not completely supported CHAR/VARCHAR type yet, and here we // replace the orc CHAR/VARCHAR with STRING type. CharVarcharUtils.replaceCharVarcharWithStringInSchema( @@ -259,4 +263,139 @@ object OrcUtils extends Logging { OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString) resultSchemaString } + + /** + * Checks if `dataType` supports columnar reads. + * + * @param dataType Data type of the orc files. + * @param nestedColumnEnabled True if columnar reads is enabled for nested column types. + * @return Returns true if data type supports columnar reads. + */ + def supportColumnarReads( + dataType: DataType, + nestedColumnEnabled: Boolean): Boolean = { + dataType match { + case _: AtomicType => true + case st: StructType if nestedColumnEnabled => + st.forall(f => supportColumnarReads(f.dataType, nestedColumnEnabled)) + case ArrayType(elementType, _) if nestedColumnEnabled => + supportColumnarReads(elementType, nestedColumnEnabled) + case MapType(keyType, valueType, _) if nestedColumnEnabled => + supportColumnarReads(keyType, nestedColumnEnabled) && + supportColumnarReads(valueType, nestedColumnEnabled) + case _ => false + } + } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + var columnsStatistics: OrcColumnStatistics = null + try { + columnsStatistics = OrcFooterReader.readStatistics(reader) + } catch { case e: Exception => + throw new SparkException( + s"Cannot read columns statistics in file: $filePath. Please consider disabling " + + s"ORC aggregate push down by setting 'spark.sql.orc.aggregatePushdown' to false.", e) + } + + // Get column statistics with column name. + def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format. + // Return null if number of non-null values is zero. 
+ def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + if (statistics.getNumberOfValues == 0) { + return null + } + + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMinMaxFromColumnStatistics should not take type $dataType " + + "for IntegerColumnStatistics") + } + case s: DoubleColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case FloatType => new FloatWritable(value.toFloat) + case DoubleType => new DoubleWritable(value) + case _ => throw new IllegalArgumentException( + s"getMinMaxFromColumnStatistics should not take type $dataType " + + "for DoubleColumnStatistics") + } + case s: DateColumnStatistics => + new DateWritable( + if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) + case _ => throw new IllegalArgumentException( + s"getMinMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + + s"$statistics as the ORC column statistics") + } + } + + val aggORCValues: Seq[WritableComparable[_]] = + aggregation.aggregateExpressions.zipWithIndex.map { + case (max: Max, index) => + val columnName = max.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema(index).dataType + getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) + case (min: Min, index) => + val columnName = min.column.fieldNames.head + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema.apply(index).dataType + getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) + case (count: Count, _) => + val columnName = count.column.fieldNames.head + val isPartitionColumn = partitionSchema.fields.map(_.name).contains(columnName) + // NOTE: Count(columnName) doesn't include null values. + // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values + // for ColumnStatistics of individual column. In addition to this, ORC also stores number + // of all values (null and non-null) separately. + val nonNullRowsCount = if (isPartitionColumn) { + columnsStatistics.getStatistics.getNumberOfValues + } else { + getColumnStatistics(columnName).getNumberOfValues + } + new LongWritable(nonNullRowsCount) + case (_: CountStar, _) => + // Count(*) includes both null and non-null values. 
+ new LongWritable(columnsStatistics.getStatistics.getNumberOfValues) + case (x, _) => + throw new IllegalArgumentException( + s"createAggInternalRowFromFooter should not take $x as the aggregate expression") + } + + val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until aggSchema.length).toArray) + orcValuesDeserializer.deserializeFromValues(aggORCValues) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 1093f9c5aa51b..0e4b9283d4866 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -32,12 +32,9 @@ import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} -import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.datasources.PartitioningUtils -import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} object ParquetUtils { def inferSchema( @@ -201,44 +198,6 @@ object ParquetUtils { converter.currentRecord } - /** - * When the aggregates (Max/Min/Count) are pushed down to Parquet, in the case of - * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader - * to read data from Parquet and aggregate at Spark layer. Instead we want - * to get the aggregates (Max/Min/Count) result using the statistics information - * from Parquet footer file, and then construct a ColumnarBatch from these aggregate results. - * - * @return Aggregate results in the format of ColumnarBatch - */ - private[sql] def createAggColumnarBatchFromFooter( - footer: ParquetMetadata, - filePath: String, - dataSchema: StructType, - partitionSchema: StructType, - aggregation: Aggregation, - aggSchema: StructType, - offHeap: Boolean, - datetimeRebaseMode: LegacyBehaviorPolicy.Value, - isCaseSensitive: Boolean): ColumnarBatch = { - val row = createAggInternalRowFromFooter( - footer, - filePath, - dataSchema, - partitionSchema, - aggregation, - aggSchema, - datetimeRebaseMode, - isCaseSensitive) - val converter = new RowToColumnConverter(aggSchema) - val columnVectors = if (offHeap) { - OffHeapColumnVector.allocateColumns(1, aggSchema) - } else { - OnHeapColumnVector.allocateColumns(1, aggSchema) - } - converter.convert(row, columnVectors.toArray) - new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) - } - /** * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics * information from Parquet footer file. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index 414252cc12481..79c34827c0bec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -23,14 +23,15 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.orc.{OrcConf, OrcFile, TypeDescription} +import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription} import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} -import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitionedFile} import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils} import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf @@ -54,7 +55,8 @@ case class OrcPartitionReaderFactory( dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, - filters: Array[Filter]) extends FilePartitionReaderFactory { + filters: Array[Filter], + aggregation: Option[Aggregation]) extends FilePartitionReaderFactory { private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val capacity = sqlConf.orcVectorizedReaderBatchSize @@ -79,17 +81,14 @@ case class OrcPartitionReaderFactory( override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { val conf = broadcastedConf.value.value - - OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) - val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) + if (aggregation.nonEmpty) { + return buildReaderWithAggregates(filePath, conf) + } - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.requestedColumnIds( isCaseSensitive, dataSchema, readDataSchema, reader, conf) } @@ -126,17 +125,14 @@ case class OrcPartitionReaderFactory( override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { val conf = broadcastedConf.value.value - - OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) - val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) + if (aggregation.nonEmpty) { + return buildColumnarReaderWithAggregates(filePath, conf) + } - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => + 
Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.requestedColumnIds( isCaseSensitive, dataSchema, readDataSchema, reader, conf) } @@ -171,4 +167,67 @@ case class OrcPartitionReaderFactory( } } + private def createORCReader(filePath: Path, conf: Configuration): Reader = { + OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) + + pushDownPredicates(filePath, conf) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + OrcFile.createReader(filePath, readerOptions) + } + + /** + * Build reader with aggregate push down. + */ + private def buildReaderWithAggregates( + filePath: Path, + conf: Configuration): PartitionReader[InternalRow] = { + new PartitionReader[InternalRow] { + private var hasNext = true + private lazy val row: InternalRow = { + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => + OrcUtils.createAggInternalRowFromFooter( + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema) + } + } + + override def next(): Boolean = hasNext + + override def get(): InternalRow = { + hasNext = false + row + } + + override def close(): Unit = {} + } + } + + /** + * Build columnar reader with aggregate push down. + */ + private def buildColumnarReaderWithAggregates( + filePath: Path, + conf: Configuration): PartitionReader[ColumnarBatch] = { + new PartitionReader[ColumnarBatch] { + private var hasNext = true + private lazy val batch: ColumnarBatch = { + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => + val row = OrcUtils.createAggInternalRowFromFooter( + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, + readDataSchema) + AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false) + } + } + + override def next(): Boolean = hasNext + + override def get(): ColumnarBatch = { + hasNext = false + batch + } + + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 7619e3c503139..6b9d181a7f4c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -21,8 +21,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -37,10 +38,25 @@ case class OrcScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, + pushedAggregate: Option[Aggregation] = None, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { - override def isSplitable(path: Path): Boolean = true + override def isSplitable(path: Path): Boolean = { + // If aggregate is pushed down, only the file footer will be read once, + // so 
file should be not split across multiple tasks. + pushedAggregate.isEmpty + } + + override def readSchema(): StructType = { + // If aggregate is pushed down, schema has already been pruned in `OrcScanBuilder` + // and no need to call super.readSchema() + if (pushedAggregate.nonEmpty) { + readDataSchema + } else { + super.readSchema() + } + } override def createReaderFactory(): PartitionReaderFactory = { val broadcastedConf = sparkSession.sparkContext.broadcast( @@ -48,24 +64,39 @@ case class OrcScan( // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, pushedFilters) + dataSchema, readDataSchema, readPartitionSchema, pushedFilters, pushedAggregate) } override def equals(obj: Any): Boolean = obj match { case o: OrcScan => + val pushedDownAggEqual = if (pushedAggregate.nonEmpty && o.pushedAggregate.nonEmpty) { + AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, o.pushedAggregate.get) + } else { + pushedAggregate.isEmpty && o.pushedAggregate.isEmpty + } super.equals(o) && dataSchema == o.dataSchema && options == o.options && - equivalentFilters(pushedFilters, o.pushedFilters) - + equivalentFilters(pushedFilters, o.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { + (seqToString(pushedAggregate.get.aggregateExpressions), + seqToString(pushedAggregate.get.groupByColumns)) + } else { + ("[]", "[]") + } + override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index cfa396f5482f4..d2c17fda4a382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.v2.orc import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf @@ -35,18 +36,31 @@ case class OrcScanBuilder( schema: StructType, dataSchema: StructType, 
options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownAggregates { + lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) } + private var finalSchema = new StructType() + + private var pushedAggregations = Option.empty[Aggregation] + override protected val supportsNestedSchemaPruning: Boolean = true override def build(): Scan = { - OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), options, pushedDataFilters, partitionFilters, dataFilters) + // the `finalSchema` is either pruned in pushAggregation (if aggregates are + // pushed down), or pruned in readDataSchema() (in regular column pruning). These + // two are mutual exclusive. + if (pushedAggregations.isEmpty) { + finalSchema = readDataSchema() + } + OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, + readPartitionSchema(), options, pushedAggregations, pushedDataFilters, partitionFilters, + dataFilters) } override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { @@ -58,4 +72,23 @@ case class OrcScanBuilder( Array.empty[Filter] } } + + override def pushAggregation(aggregation: Aggregation): Boolean = { + if (!sparkSession.sessionState.conf.orcAggregatePushDown) { + return false + } + + AggregatePushDownUtils.getSchemaForPushedAggregation( + aggregation, + schema, + partitionNameSet, + dataFilters) match { + + case Some(schema) => + finalSchema = schema + this.pushedAggregations = Some(aggregation) + true + case _ => false + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 111018b579ed2..6f021ff2e97f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} -import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf @@ -175,24 +175,26 @@ case class ParquetPartitionReaderFactory( } else { new PartitionReader[ColumnarBatch] { private var hasNext = true - private val row: ColumnarBatch = { + private val batch: ColumnarBatch = { val footer = getFooter(file) if (footer != null && footer.getBlocks.size > 0) { - ParquetUtils.createAggColumnarBatchFromFooter(footer, file.filePath, dataSchema, - partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector, + val row = ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, + dataSchema, partitionSchema, aggregation.get, 
readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + AggregatePushDownUtils.convertAggregatesRowToBatch( + row, readDataSchema, enableOffHeapColumnVector && Option(TaskContext.get()).isDefined) } else { null } } override def next(): Boolean = { - hasNext && row != null + hasNext && batch != null } override def get(): ColumnarBatch = { hasNext = false - row + batch } override def close(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 42dc287f73129..b92ed82190ae8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf @@ -101,7 +101,7 @@ case class ParquetScan( override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) { - equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) + AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) } else { pushedAggregate.isEmpty && p.pushedAggregate.isEmpty } @@ -130,10 +130,4 @@ case class ParquetScan( Map("PushedAggregation" -> pushedAggregationsStr) ++ Map("PushedGroupBy" -> pushedGroupByStr) } - - private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { - a.aggregateExpressions.sortBy(_.hashCode()) - .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && - a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 01523250991cf..d198321eacdb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.expressions.NamedReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, 
PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -83,84 +82,31 @@ case class ParquetScanBuilder( override def pushedFilters(): Array[Filter] = pushedParquetFilters override def pushAggregation(aggregation: Aggregation): Boolean = { - - def getStructFieldForCol(col: NamedReference): StructField = { - schema.nameToField(col.fieldNames.head) - } - - def isPartitionCol(col: NamedReference) = { - partitionNameSet.contains(col.fieldNames.head) - } - - def processMinOrMax(agg: AggregateFunc): Boolean = { - val (column, aggType) = agg match { - case max: Max => (max.column, "max") - case min: Min => (min.column, "min") - case _ => - throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") - } - - if (isPartitionCol(column)) { - // don't push down partition column, footer doesn't have max/min for partition column - return false - } - val structField = getStructFieldForCol(column) - - structField.dataType match { - // not push down complex type - // not push down Timestamp because INT96 sort order is undefined, - // Parquet doesn't return statistics for INT96 - case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => - false - case _ => - finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) - true - } - } - - if (!sparkSession.sessionState.conf.parquetAggregatePushDown || - aggregation.groupByColumns.nonEmpty || dataFilters.length > 0) { - // Parquet footer has max/min/count for columns - // e.g. SELECT COUNT(col1) FROM t - // but footer doesn't have max/min/count for a column if max/min/count - // are combined with filter or group by - // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8 - // SELECT COUNT(col1) FROM t GROUP BY col2 - // Todo: 1. add support if groupby column is partition col - // (https://issues.apache.org/jira/browse/SPARK-36646) - // 2. 
add support if filter col is partition col - // (https://issues.apache.org/jira/browse/SPARK-36647) + if (!sparkSession.sessionState.conf.parquetAggregatePushDown) { return false } - aggregation.groupByColumns.foreach { col => - if (col.fieldNames.length != 1) return false - finalSchema = finalSchema.add(getStructFieldForCol(col)) + AggregatePushDownUtils.getSchemaForPushedAggregation( + aggregation, + schema, + partitionNameSet, + dataFilters) match { + + case Some(schema) => + finalSchema = schema + this.pushedAggregations = Some(aggregation) + true + case _ => false } - - aggregation.aggregateExpressions.foreach { - case max: Max => - if (!processMinOrMax(max)) return false - case min: Min => - if (!processMinOrMax(min)) return false - case count: Count => - if (count.column.fieldNames.length != 1 || count.isDistinct) return false - finalSchema = - finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) - case _: CountStar => - finalSchema = finalSchema.add(StructField("count(*)", LongType)) - case _ => - return false - } - this.pushedAggregations = Some(aggregation) - true } override def build(): Scan = { // the `finalSchema` is either pruned in pushAggregation (if aggregates are // pushed down), or pruned in readDataSchema() (in regular column pruning). These // two are mutual exclusive. - if (pushedAggregations.isEmpty) finalSchema = readDataSchema() + if (pushedAggregations.isEmpty) { + finalSchema = readDataSchema() + } ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, readPartitionSchema(), pushedParquetFilters, options, pushedAggregations, partitionFilters, dataFilters) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 11934c2f08f4c..e213b32fdd243 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -358,7 +358,7 @@ class FileScanSuite extends FileScanSuiteBase { Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => - OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df), + OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, None, f, pf, df), Seq.empty), ("CSVScan", (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala similarity index 59% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index c795bd9ff3389..a3d01e483209a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -15,33 +15,39 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.parquet +package org.apache.spark.sql.execution.datasources import java.sql.{Date, Timestamp} import org.apache.spark.SparkConf -import org.apache.spark.sql._ +import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.sql.execution.datasources.orc.OrcTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.functions.min import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType, TimestampType} /** - * A test suite that tests Max/Min/Count push down. + * A test suite that tests aggregate push down for Parquet and ORC. */ -abstract class ParquetAggregatePushDownSuite +trait FileSourceAggregatePushDownSuite extends QueryTest - with ParquetTest + with FileBasedDataSourceTest with SharedSparkSession with ExplainSuiteHelper { + import testImplicits._ - test("aggregate push down - nested column: Max(top level column) not push down") { + protected def format: String + // The SQL config key for enabling aggregate push down. + protected val aggPushDownEnabledKey: String + + test("nested column: Max(top level column) not push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { val max = sql("SELECT Max(_1) FROM t") max.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -53,11 +59,10 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - nested column: Count(top level column) push down") { + test("nested column: Count(top level column) push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { val count = sql("SELECT Count(_1) FROM t") count.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -70,11 +75,10 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - nested column: Max(nested column) not push down") { + test("nested column: Max(nested sub-field) not push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey-> "true") { + withDataSourceTable(data, "t") { val max = sql("SELECT Max(_1._2[0]) FROM t") max.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -86,11 +90,10 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - nested column: Count(nested column) not push down") { + test("nested column: Count(nested sub-field) not push down") { val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { - withParquetTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { val count = 
sql("SELECT Count(_1._2[0]) FROM t") count.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -103,13 +106,13 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - Max(partition Col): not push dow") { + test("Max(partition column): not push down") { withTempPath { dir => spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") + withSQLConf(aggPushDownEnabledKey -> "true") { val max = sql("SELECT Max(p) FROM tmp") max.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -123,15 +126,15 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - Count(partition Col): push down") { + test("Count(partition column): push down") { withTempPath { dir => - spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + spark.range(10).selectExpr("if(id % 2 = 0, null, id) AS n", "id % 3 as p") + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") val enableVectorizedReader = Seq("false", "true") for (testVectorizedReader <- enableVectorizedReader) { - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + withSQLConf(aggPushDownEnabledKey -> "true", vectorizedReaderEnabledKey -> testVectorizedReader) { val count = sql("SELECT COUNT(p) FROM tmp") count.queryExecution.optimizedPlan.collect { @@ -147,12 +150,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - Filter alias over aggregate") { + test("filter alias over aggregate") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1") selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -165,12 +167,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - alias over aggregate") { + test("alias over aggregate") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t") selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -183,12 +184,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - aggregate over alias not push down") { + test("aggregate over alias not push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - 
SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val df = spark.table("t") val query = df.select($"_1".as("col1")).agg(min($"col1")) query.queryExecution.optimizedPlan.collect { @@ -202,12 +202,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - query with group by not push down") { + test("query with group by not push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 7)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { // aggregate not pushed down if there is group by val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ") selectAgg.queryExecution.optimizedPlan.collect { @@ -221,12 +220,11 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - query with filter not push down") { + test("aggregate with data filter cannot be pushed down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 7)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { // aggregate not pushed down if there is filter val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0") selectAgg.queryExecution.optimizedPlan.collect { @@ -240,12 +238,34 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - push down only if all the aggregates can be pushed down") { + test("aggregate with partition filter can be pushed down") { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") + Seq("false", "true").foreach { enableVectorizedReader => + withSQLConf(aggPushDownEnabledKey -> "true", + vectorizedReaderEnabledKey -> enableVectorizedReader) { + val max = sql("SELECT max(id), min(id), count(id) FROM tmp WHERE p = 0") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(id), MIN(id), COUNT(id)]" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + checkAnswer(max, Seq(Row(9, 0, 4))) + } + } + } + } + } + + test("push down only if all the aggregates can be pushed down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 7)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { // not push down since sum can't be pushed down val selectAgg = sql("SELECT min(_1), sum(_3) FROM t") selectAgg.queryExecution.optimizedPlan.collect { @@ -262,9 +282,8 @@ abstract class ParquetAggregatePushDownSuite test("aggregate push down - MIN/MAX/COUNT") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) - withParquetTable(data, "t") { - withSQLConf( - SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { val selectAgg = sql("SELECT min(_3), min(_3), max(_3), 
min(_1), max(_1), max(_1)," + " count(*), count(_1), count(_2), count(_3) FROM t") selectAgg.queryExecution.optimizedPlan.collect { @@ -286,7 +305,13 @@ abstract class ParquetAggregatePushDownSuite } } - test("aggregate push down - different data types") { + private def testPushDownForAllDataTypes( + inputRows: Seq[Row], + expectedMinWithAllTypes: Seq[Row], + expectedMinWithOutTSAndBinary: Seq[Row], + expectedMaxWithAllTypes: Seq[Row], + expectedMaxWithOutTSAndBinary: Seq[Row], + expectedCount: Seq[Row]): Unit = { implicit class StringToDate(s: String) { def date: Date = Date.valueOf(s) } @@ -295,49 +320,6 @@ abstract class ParquetAggregatePushDownSuite def ts: Timestamp = Timestamp.valueOf(s) } - val rows = - Seq( - Row( - "a string", - true, - 10.toByte, - "Spark SQL".getBytes, - 12.toShort, - 3, - Long.MaxValue, - 0.15.toFloat, - 0.75D, - Decimal("12.345678"), - ("2021-01-01").date, - ("2015-01-01 23:50:59.123").ts), - Row( - "test string", - false, - 1.toByte, - "Parquet".getBytes, - 2.toShort, - null, - Long.MinValue, - 0.25.toFloat, - 0.85D, - Decimal("1.2345678"), - ("2015-01-01").date, - ("2021-01-01 23:50:59.123").ts), - Row( - null, - true, - 10000.toByte, - "Spark ML".getBytes, - 222.toShort, - 113, - 11111111L, - 0.25.toFloat, - 0.75D, - Decimal("12345.678"), - ("2004-06-19").date, - ("1999-08-26 10:43:59.123").ts) - ) - val schema = StructType(List(StructField("StringCol", StringType, true), StructField("BooleanCol", BooleanType, false), StructField("ByteCol", ByteType, false), @@ -351,99 +333,91 @@ abstract class ParquetAggregatePushDownSuite StructField("DateCol", DateType, false), StructField("TimestampCol", TimestampType, false)).toArray) - val rdd = sparkContext.parallelize(rows) + val rdd = sparkContext.parallelize(inputRows) withTempPath { file => - spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath) + spark.createDataFrame(rdd, schema).write.format(format).save(file.getCanonicalPath) withTempView("test") { - spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test") - val enableVectorizedReader = Seq("false", "true") - for (testVectorizedReader <- enableVectorizedReader) { - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", - vectorizedReaderEnabledKey -> testVectorizedReader) { + spark.read.format(format).load(file.getCanonicalPath).createOrReplaceTempView("test") + Seq("false", "true").foreach { enableVectorizedReader => + withSQLConf(aggPushDownEnabledKey -> "true", + vectorizedReaderEnabledKey -> enableVectorizedReader) { - val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + val testMinWithAllTypes = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test") // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down - testMinWithTS.queryExecution.optimizedPlan.collect { + // In addition, Parquet Binary min/max could be truncated, so we disable aggregate + // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType). + // Also do not push down for ORC with same reason. 
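// Illustrative summary (not part of this patch): with the schema defined above, MIN/MAX can
// be served from footer statistics for BooleanCol, ByteCol, ShortCol, IntegerCol, LongCol,
// FloatCol, DoubleCol and DateCol, while a query that also asks for StringCol, BinaryCol,
// DecimalCol or TimestampCol is not pushed at all ("PushedAggregation: []"), since push down
// only happens when every selected aggregate is supported. The MIN queries below (and the
// matching MAX queries) exercise exactly this split.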
+ testMinWithAllTypes.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregation: []" - checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment) + checkKeywordsExistsInExplain(testMinWithAllTypes, expected_plan_fragment) } - checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, - 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, - ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts))) + checkAnswer(testMinWithAllTypes, expectedMinWithAllTypes) - val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + - "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + - "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test") + val testMinWithOutTSAndBinary = sql("SELECT min(BooleanCol), min(ByteCol), " + + "min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DateCol) FROM test") - testMinWithOutTS.queryExecution.optimizedPlan.collect { + testMinWithOutTSAndBinary.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [MIN(StringCol), " + - "MIN(BooleanCol), " + + "PushedAggregation: [MIN(BooleanCol), " + "MIN(ByteCol), " + - "MIN(BinaryCol), " + "MIN(ShortCol), " + "MIN(IntegerCol), " + "MIN(LongCol), " + "MIN(FloatCol), " + "MIN(DoubleCol), " + - "MIN(DecimalCol), " + "MIN(DateCol)]" - checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment) + checkKeywordsExistsInExplain(testMinWithOutTSAndBinary, expected_plan_fragment) } - checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, - 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, - ("2004-06-19").date))) + checkAnswer(testMinWithOutTSAndBinary, expectedMinWithOutTSAndBinary) - val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + - "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + - "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test") + val testMaxWithAllTypes = sql("SELECT max(StringCol), max(BooleanCol), " + + "max(ByteCol), max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), " + + "max(FloatCol), max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) " + + "FROM test") // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down - testMaxWithTS.queryExecution.optimizedPlan.collect { + // In addition, Parquet Binary min/max could be truncated, so we disable aggregate + // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType). + // Also do not push down for ORC with same reason. 
+ testMaxWithAllTypes.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregation: []" - checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment) + checkKeywordsExistsInExplain(testMaxWithAllTypes, expected_plan_fragment) } - checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte, - "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, - 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) + checkAnswer(testMaxWithAllTypes, expectedMaxWithAllTypes) - val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + - "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + - "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test") + val testMaxWithoutTSAndBinary = sql("SELECT max(BooleanCol), max(ByteCol), " + + "max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + "max(DoubleCol), max(DateCol) FROM test") - testMaxWithoutTS.queryExecution.optimizedPlan.collect { + testMaxWithoutTSAndBinary.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [MAX(StringCol), " + - "MAX(BooleanCol), " + + "PushedAggregation: [MAX(BooleanCol), " + "MAX(ByteCol), " + - "MAX(BinaryCol), " + "MAX(ShortCol), " + "MAX(IntegerCol), " + "MAX(LongCol), " + "MAX(FloatCol), " + "MAX(DoubleCol), " + - "MAX(DecimalCol), " + "MAX(DateCol)]" - checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment) + checkKeywordsExistsInExplain(testMaxWithoutTSAndBinary, expected_plan_fragment) } - checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte, - "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, - 12345.678, ("2021-01-01").date))) + checkAnswer(testMaxWithoutTSAndBinary, expectedMaxWithOutTSAndBinary) val testCount = sql("SELECT count(StringCol), count(BooleanCol)," + " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + @@ -469,23 +443,97 @@ abstract class ParquetAggregatePushDownSuite checkKeywordsExistsInExplain(testCount, expected_plan_fragment) } - checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))) + checkAnswer(testCount, expectedCount) } } } } } - test("aggregate push down - column name case sensitivity") { - val enableVectorizedReader = Seq("false", "true") - for (testVectorizedReader <- enableVectorizedReader) { - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", - vectorizedReaderEnabledKey -> testVectorizedReader) { + test("aggregate push down - different data types") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } + + val rows = + Seq( + Row( + "a string", + true, + 10.toByte, + "Spark SQL".getBytes, + 12.toShort, + 3, + Long.MaxValue, + 0.15.toFloat, + 0.75D, + Decimal("12.345678"), + ("2021-01-01").date, + ("2015-01-01 23:50:59.123").ts), + Row( + "test string", + false, + 1.toByte, + "Parquet".getBytes, + 2.toShort, + null, + Long.MinValue, + 0.25.toFloat, + 0.85D, + Decimal("1.2345678"), + ("2015-01-01").date, + ("2021-01-01 23:50:59.123").ts), + Row( + null, + true, + 10000.toByte, + "Spark ML".getBytes, + 222.toShort, + 113, + 11111111L, + 0.25.toFloat, + 0.75D, + Decimal("12345.678"), + ("2004-06-19").date, + ("1999-08-26 10:43:59.123").ts) + ) + + testPushDownForAllDataTypes( + rows, + Seq(Row("a string", 
false, 1.toByte, + "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, + 1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)), + Seq(Row(false, 1.toByte, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date)), + Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)), + Seq(Row(true, 16.toByte, + 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, ("2021-01-01").date)), + Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)) + ) + + // Test for 0 row (empty file) + val nullRow = Row.fromSeq((1 to 12).map(_ => null)) + val nullRowWithOutTSAndBinary = Row.fromSeq((1 to 8).map(_ => null)) + val zeroCount = Row.fromSeq((1 to 12).map(_ => 0)) + testPushDownForAllDataTypes(Seq.empty, Seq(nullRow), Seq(nullRowWithOutTSAndBinary), + Seq(nullRow), Seq(nullRowWithOutTSAndBinary), Seq(zeroCount)) + } + + test("column name case sensitivity") { + Seq("false", "true").foreach { enableVectorizedReader => + withSQLConf(aggPushDownEnabledKey -> "true", + vectorizedReaderEnabledKey -> enableVectorizedReader) { withTempPath { dir => spark.range(10).selectExpr("id", "id % 3 as p") - .write.partitionBy("p").parquet(dir.getCanonicalPath) + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) withTempView("tmp") { - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp") selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -501,18 +549,41 @@ abstract class ParquetAggregatePushDownSuite } } +abstract class ParquetAggregatePushDownSuite + extends FileSourceAggregatePushDownSuite with ParquetTest { + + override def format: String = "parquet" + override protected val aggPushDownEnabledKey: String = + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key +} + class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite { override protected def sparkConf: SparkConf = - super - .sparkConf - .set(SQLConf.USE_V1_SOURCE_LIST, "parquet") + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet") } class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite { override protected def sparkConf: SparkConf = - super - .sparkConf - .set(SQLConf.USE_V1_SOURCE_LIST, "") + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") +} + +abstract class OrcAggregatePushDownSuite extends OrcTest with FileSourceAggregatePushDownSuite { + + override def format: String = "orc" + override protected val aggPushDownEnabledKey: String = + SQLConf.ORC_AGGREGATE_PUSHDOWN_ENABLED.key +} + +class OrcV1AggregatePushDownSuite extends OrcAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "orc") +} + +class OrcV2AggregatePushDownSuite extends OrcAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") } From 4c2380ba67a7d3329c044cdffec6cebbf77291d4 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 10 Feb 2022 21:25:53 +0800 Subject: [PATCH 28/53] [SPARK-37960][SQL] A new framework to represent catalyst expressions in DS v2 APIs ### What changes were proposed in this pull request? This PR provides a new framework to represent catalyst expressions in DS v2 APIs. 
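As a rough illustration (an editorial sketch, not part of this patch; the column name and literals are made up), the classes introduced below let a catalyst conditional expression travel to a data source as a plain SQL string and feed an aggregate that the source can push down:

    import org.apache.spark.sql.connector.expressions.GeneralSQLExpression
    import org.apache.spark.sql.connector.expressions.aggregate.Sum

    // ExpressionSQLBuilder renders `CASE WHEN salary > 10000 THEN 2 ELSE 1 END` roughly as the
    // SQL string below, which is wrapped into the new general expression type...
    val caseWhen = new GeneralSQLExpression("CASE WHEN (salary) > (10000) THEN 2 ELSE 1 END")
    // ...and can then back an aggregate such as SUM(CASE WHEN ... END) in the pushed Aggregation.
    val pushedSum = new Sum(caseWhen, false)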
`GeneralSQLExpression` is a general SQL expression to represent catalyst expression in DS v2 API. `ExpressionSQLBuilder` is a builder to generate `GeneralSQLExpression` from catalyst expressions. `CASE ... WHEN ... ELSE ... END` is just the first use case. This PR also supports aggregate push down with `CASE ... WHEN ... ELSE ... END`. ### Why are the changes needed? Support aggregate push down with `CASE ... WHEN ... ELSE ... END`. ### Does this PR introduce _any_ user-facing change? Yes. Users could use `CASE ... WHEN ... ELSE ... END` with aggregate push down. ### How was this patch tested? New tests. Closes #35248 from beliefer/SPARK-37960. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../expressions/GeneralSQLExpression.java | 41 +++++++++++ .../connector/expressions/aggregate/Avg.java | 14 ++-- .../expressions/aggregate/Count.java | 14 ++-- .../connector/expressions/aggregate/Max.java | 10 +-- .../connector/expressions/aggregate/Min.java | 10 +-- .../connector/expressions/aggregate/Sum.java | 14 ++-- .../catalyst/util/ExpressionSQLBuilder.scala | 69 +++++++++++++++++++ .../datasources/AggregatePushDownUtils.scala | 39 ++++++----- .../datasources/DataSourceStrategy.scala | 57 ++++++++------- .../execution/datasources/orc/OrcUtils.scala | 13 ++-- .../datasources/parquet/ParquetUtils.scala | 20 +++--- .../datasources/v2/V2ColumnUtils.scala | 27 ++++++++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 55 +++++++++++---- .../FileSourceAggregatePushDownSuite.scala | 28 ++++++++ .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 67 ++++++++++++++++-- 15 files changed, 371 insertions(+), 107 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java new file mode 100644 index 0000000000000..ebeee22a853cf --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import java.io.Serializable; + +import org.apache.spark.annotation.Evolving; + +/** + * The general SQL string corresponding to expression. 
+ * + * @since 3.3.0 + */ +@Evolving +public class GeneralSQLExpression implements Expression, Serializable { + private String sql; + + public GeneralSQLExpression(String sql) { + this.sql = sql; + } + + public String sql() { return sql; } + + @Override + public String toString() { return sql; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java index 5e10ec9ee1644..cc9d27ab8e59c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the mean of all the values in a group. @@ -27,23 +27,23 @@ */ @Evolving public final class Avg implements AggregateFunc { - private final NamedReference column; + private final Expression input; private final boolean isDistinct; - public Avg(NamedReference column, boolean isDistinct) { - this.column = column; + public Avg(Expression column, boolean isDistinct) { + this.input = column; this.isDistinct = isDistinct; } - public NamedReference column() { return column; } + public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } @Override public String toString() { if (isDistinct) { - return "AVG(DISTINCT " + column.describe() + ")"; + return "AVG(DISTINCT " + input.describe() + ")"; } else { - return "AVG(" + column.describe() + ")"; + return "AVG(" + input.describe() + ")"; } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java index 1685770604a46..54c64b83c5d52 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the number of the specific row in a group. 
@@ -27,23 +27,23 @@ */ @Evolving public final class Count implements AggregateFunc { - private final NamedReference column; + private final Expression input; private final boolean isDistinct; - public Count(NamedReference column, boolean isDistinct) { - this.column = column; + public Count(Expression column, boolean isDistinct) { + this.input = column; this.isDistinct = isDistinct; } - public NamedReference column() { return column; } + public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } @Override public String toString() { if (isDistinct) { - return "COUNT(DISTINCT " + column.describe() + ")"; + return "COUNT(DISTINCT " + input.describe() + ")"; } else { - return "COUNT(" + column.describe() + ")"; + return "COUNT(" + input.describe() + ")"; } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java index 5acdf14bf7e2f..971aac279e09b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the maximum value in a group. @@ -27,12 +27,12 @@ */ @Evolving public final class Max implements AggregateFunc { - private final NamedReference column; + private final Expression input; - public Max(NamedReference column) { this.column = column; } + public Max(Expression column) { this.input = column; } - public NamedReference column() { return column; } + public Expression column() { return input; } @Override - public String toString() { return "MAX(" + column.describe() + ")"; } + public String toString() { return "MAX(" + input.describe() + ")"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java index 824c607ea7df0..8d0644b0f0103 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the minimum value in a group. 
@@ -27,12 +27,12 @@ */ @Evolving public final class Min implements AggregateFunc { - private final NamedReference column; + private final Expression input; - public Min(NamedReference column) { this.column = column; } + public Min(Expression column) { this.input = column; } - public NamedReference column() { return column; } + public Expression column() { return input; } @Override - public String toString() { return "MIN(" + column.describe() + ")"; } + public String toString() { return "MIN(" + input.describe() + ")"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java index 6b04dc38c2846..721ef31c9a817 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the summation of all the values in a group. @@ -27,23 +27,23 @@ */ @Evolving public final class Sum implements AggregateFunc { - private final NamedReference column; + private final Expression input; private final boolean isDistinct; - public Sum(NamedReference column, boolean isDistinct) { - this.column = column; + public Sum(Expression column, boolean isDistinct) { + this.input = column; this.isDistinct = isDistinct; } - public NamedReference column() { return column; } + public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } @Override public String toString() { if (isDistinct) { - return "SUM(DISTINCT " + column.describe() + ")"; + return "SUM(DISTINCT " + input.describe() + ")"; } else { - return "SUM(" + column.describe() + ")"; + return "SUM(" + input.describe() + ")"; } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala new file mode 100644 index 0000000000000..6239d0e2e7ae8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryOperator, CaseWhen, EqualTo, Expression, IsNotNull, IsNull, Literal, Not} +import org.apache.spark.sql.connector.expressions.LiteralValue + +/** + * The builder to generate SQL string from catalyst expressions. + */ +class ExpressionSQLBuilder(e: Expression) { + + def build(): Option[String] = generateSQL(e) + + private def generateSQL(expr: Expression): Option[String] = expr match { + case Literal(value, dataType) => Some(LiteralValue(value, dataType).toString) + case a: Attribute => Some(quoteIfNeeded(a.name)) + case IsNull(col) => generateSQL(col).map(c => s"$c IS NULL") + case IsNotNull(col) => generateSQL(col).map(c => s"$c IS NOT NULL") + case b: BinaryOperator => + val l = generateSQL(b.left) + val r = generateSQL(b.right) + if (l.isDefined && r.isDefined) { + Some(s"(${l.get}) ${b.sqlOperator} (${r.get})") + } else { + None + } + case Not(EqualTo(left, right)) => + val l = generateSQL(left) + val r = generateSQL(right) + if (l.isDefined && r.isDefined) { + Some(s"${l.get} != ${r.get}") + } else { + None + } + case Not(child) => generateSQL(child).map(v => s"NOT ($v)") + case CaseWhen(branches, elseValue) => + val conditionsSQL = branches.map(_._1).flatMap(generateSQL) + val valuesSQL = branches.map(_._2).flatMap(generateSQL) + if (conditionsSQL.length == branches.length && valuesSQL.length == branches.length) { + val branchSQL = + conditionsSQL.zip(valuesSQL).map { case (c, v) => s" WHEN $c THEN $v" }.mkString + if (elseValue.isDefined) { + elseValue.flatMap(generateSQL).map(v => s"CASE$branchSQL ELSE $v END") + } else { + Some(s"CASE$branchSQL END") + } + } else { + None + } + // TODO supports other expressions + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala index 6340d97af1a04..6d8cae544f23e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -42,27 +42,28 @@ object AggregatePushDownUtils { var finalSchema = new StructType() - def getStructFieldForCol(col: NamedReference): StructField = { - schema.apply(col.fieldNames.head) + def getStructFieldForCol(colName: String): StructField = { + schema.apply(colName) } - def isPartitionCol(col: NamedReference) = { - partitionNames.contains(col.fieldNames.head) + def isPartitionCol(colName: String) = { + partitionNames.contains(colName) } def processMinOrMax(agg: AggregateFunc): Boolean = { - val 
(column, aggType) = agg match { - case max: Max => (max.column, "max") - case min: Min => (min.column, "min") - case _ => - throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") + val (columnName, aggType) = agg match { + case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined => + (V2ColumnUtils.extractV2Column(max.column).get, "max") + case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined => + (V2ColumnUtils.extractV2Column(min.column).get, "min") + case _ => return false } - if (isPartitionCol(column)) { + if (isPartitionCol(columnName)) { // don't push down partition column, footer doesn't have max/min for partition column return false } - val structField = getStructFieldForCol(column) + val structField = getStructFieldForCol(columnName) structField.dataType match { // not push down complex type @@ -93,16 +94,22 @@ object AggregatePushDownUtils { // (https://issues.apache.org/jira/browse/SPARK-36646) return None } + aggregation.groupByColumns.foreach { col => + // don't push down if the group by columns are not the same as the partition columns (orders + // doesn't matter because reorder can be done at data source layer) + if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None + finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head)) + } aggregation.aggregateExpressions.foreach { case max: Max => if (!processMinOrMax(max)) return None case min: Min => if (!processMinOrMax(min)) return None - case count: Count => - if (count.column.fieldNames.length != 1 || count.isDistinct) return None - finalSchema = - finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) + case count: Count + if V2ColumnUtils.extractV2Column(count.column).isDefined && !count.isDistinct => + val columnName = V2ColumnUtils.extractV2Column(count.column).get + finalSchema = finalSchema.add(StructField(s"count($columnName)", LongType)) case _: CountStar => finalSchema = finalSchema.add(StructField("count(*)", LongType)) case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1934ef9f03228..29c73ba0cf59c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -38,9 +38,10 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.catalyst.util.ExpressionSQLBuilder import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} +import org.apache.spark.sql.connector.expressions.{Expression => ExpressionV2, FieldReference, GeneralSQLExpression, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import 
org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} @@ -699,46 +700,44 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } - protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = { - if (aggregates.filter.isEmpty) { - aggregates.aggregateFunction match { - case aggregate.Min(PushableColumnWithoutNestedColumn(name)) => - Some(new Min(FieldReference(name))) - case aggregate.Max(PushableColumnWithoutNestedColumn(name)) => - Some(new Max(FieldReference(name))) + protected[sql] def translateAggregate(agg: AggregateExpression): Option[AggregateFunc] = { + if (agg.filter.isEmpty) { + agg.aggregateFunction match { + case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) + case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) case count: aggregate.Count if count.children.length == 1 => count.children.head match { - // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table + // COUNT(any literal) is the same as COUNT(*) case Literal(_, _) => Some(new CountStar()) - case PushableColumnWithoutNestedColumn(name) => - Some(new Count(FieldReference(name), aggregates.isDistinct)) + case PushableExpression(expr) => Some(new Count(expr, agg.isDistinct)) case _ => None } - case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => - Some(new Sum(FieldReference(name), aggregates.isDistinct)) - case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) => - Some(new Avg(FieldReference(name), aggregates.isDistinct)) + case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, agg.isDistinct)) + case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, agg.isDistinct)) case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc("VAR_POP", aggregates.isDistinct, Array(FieldReference(name)))) + Some(new GeneralAggregateFunc( + "VAR_POP", agg.isDistinct, Array(FieldReference(name)))) case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc("VAR_SAMP", aggregates.isDistinct, Array(FieldReference(name)))) + Some(new GeneralAggregateFunc( + "VAR_SAMP", agg.isDistinct, Array(FieldReference(name)))) case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc("STDDEV_POP", aggregates.isDistinct, Array(FieldReference(name)))) + Some(new GeneralAggregateFunc( + "STDDEV_POP", agg.isDistinct, Array(FieldReference(name)))) case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) => - Some(new GeneralAggregateFunc("STDDEV_SAMP", aggregates.isDistinct, Array(FieldReference(name)))) + Some(new GeneralAggregateFunc( + "STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name)))) case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left), PushableColumnWithoutNestedColumn(right), _) => - Some(new GeneralAggregateFunc("COVAR_POP", aggregates.isDistinct, + Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct, Array(FieldReference(left), FieldReference(right)))) case aggregate.CovSample(PushableColumnWithoutNestedColumn(left), PushableColumnWithoutNestedColumn(right), _) => - Some(new GeneralAggregateFunc("COVAR_SAMP", aggregates.isDistinct, + Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct, Array(FieldReference(left), FieldReference(right)))) case aggregate.Corr(PushableColumnWithoutNestedColumn(left), PushableColumnWithoutNestedColumn(right), _) => - 
Some(new GeneralAggregateFunc("CORR", aggregates.isDistinct, + Some(new GeneralAggregateFunc("CORR", agg.isDistinct, Array(FieldReference(left), FieldReference(right)))) - case _ => None } } else { @@ -756,7 +755,7 @@ object DataSourceStrategy def columnAsString(e: Expression): Option[FieldReference] = e match { case PushableColumnWithoutNestedColumn(name) => - Some(FieldReference.column(name).asInstanceOf[FieldReference]) + Some(FieldReference(name).asInstanceOf[FieldReference]) case _ => None } @@ -854,3 +853,13 @@ object PushableColumnAndNestedColumn extends PushableColumnBase { object PushableColumnWithoutNestedColumn extends PushableColumnBase { override val nestedPredicatePushdownEnabled = false } + +/** + * Get the expression of DS V2 to represent catalyst expression that can be pushed down. + */ +object PushableExpression { + def unapply(e: Expression): Option[ExpressionV2] = e match { + case PushableColumnWithoutNestedColumn(name) => Some(FieldReference(name)) + case _ => new ExpressionSQLBuilder(e).build().map(new GeneralSQLExpression(_)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 4cf13a90ce829..7758d6a515b51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils} import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.SchemaMergeUtils +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import org.apache.spark.sql.types._ import org.apache.spark.util.{ThreadUtils, Utils} @@ -364,18 +365,18 @@ object OrcUtils extends Logging { val aggORCValues: Seq[WritableComparable[_]] = aggregation.aggregateExpressions.zipWithIndex.map { - case (max: Max, index) => - val columnName = max.column.fieldNames.head + case (max: Max, index) if V2ColumnUtils.extractV2Column(max.column).isDefined => + val columnName = V2ColumnUtils.extractV2Column(max.column).get val statistics = getColumnStatistics(columnName) val dataType = aggSchema(index).dataType getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) - case (min: Min, index) => - val columnName = min.column.fieldNames.head + case (min: Min, index) if V2ColumnUtils.extractV2Column(min.column).isDefined => + val columnName = V2ColumnUtils.extractV2Column(min.column).get val statistics = getColumnStatistics(columnName) val dataType = aggSchema.apply(index).dataType getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) - case (count: Count, _) => - val columnName = count.column.fieldNames.head + case (count: Count, _) if V2ColumnUtils.extractV2Column(count.column).isDefined => + val columnName = V2ColumnUtils.extractV2Column(count.column).get val isPartitionColumn = partitionSchema.fields.map(_.name).contains(columnName) // NOTE: Count(columnName) doesn't include null values. 
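// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the PushableExpression extractor
// added above has two outcomes -- a plain top-level column becomes a
// FieldReference, while other expressions (e.g. CASE WHEN) are attempted as a
// GeneralSQLExpression built from their SQL text, and anything with no SQL
// form (such as a Scala UDF) is not pushed at all. The types below are
// simplified stand-ins for the real connector classes; all names here are
// assumptions made only for this example.
// ---------------------------------------------------------------------------
object PushableExpressionSketch {
  sealed trait V2Expr
  final case class FieldRef(name: String) extends V2Expr        // stand-in for FieldReference
  final case class GeneralSQL(sql: String) extends V2Expr       // stand-in for GeneralSQLExpression

  sealed trait CatalystExpr
  final case class ColumnRef(name: String) extends CatalystExpr
  final case class RenderedSQL(sql: String) extends CatalystExpr // pretend the SQL builder produced this text
  case object NotTranslatable extends CatalystExpr               // e.g. a UDF call: no SQL form

  // Mirrors the unapply logic: column -> FieldRef, otherwise try rendered SQL.
  def toV2(e: CatalystExpr): Option[V2Expr] = e match {
    case ColumnRef(name)  => Some(FieldRef(name))
    case RenderedSQL(sql) => Some(GeneralSQL(sql))
    case NotTranslatable  => None
  }

  def main(args: Array[String]): Unit = {
    println(toV2(ColumnRef("salary")))                                    // Some(FieldRef(salary))
    println(toV2(RenderedSQL("CASE WHEN SALARY > 0 THEN 1 ELSE 0 END")))  // pushed as SQL text
    println(toV2(NotTranslatable))                                        // None -> aggregate stays in Spark
  }
}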
// org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 0e4b9283d4866..f3836ab8b5ae4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} import org.apache.spark.sql.types.StructType @@ -230,33 +230,33 @@ object ParquetUtils { blocks.forEach { block => val blockMetaData = block.getColumns agg match { - case max: Max => - val colName = max.column.fieldNames.head + case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined => + val colName = V2ColumnUtils.extractV2Column(max.column).get index = dataSchema.fieldNames.toList.indexOf(colName) schemaName = "max(" + colName + ")" val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true) if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { value = currentMax } - case min: Min => - val colName = min.column.fieldNames.head + case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined => + val colName = V2ColumnUtils.extractV2Column(min.column).get index = dataSchema.fieldNames.toList.indexOf(colName) schemaName = "min(" + colName + ")" val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false) if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { value = currentMin } - case count: Count => - schemaName = "count(" + count.column.fieldNames.head + ")" + case count: Count if V2ColumnUtils.extractV2Column(count.column).isDefined => + val colName = V2ColumnUtils.extractV2Column(count.column).get + schemaName = "count(" + colName + ")" rowCount += block.getRowCount var isPartitionCol = false - if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) - .toSet.contains(count.column.fieldNames.head)) { + if (partitionSchema.fields.map(_.name).toSet.contains(colName)) { isPartitionCol = true } isCount = true if (!isPartitionCol) { - index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head) + index = dataSchema.fieldNames.toList.indexOf(colName) // Count(*) includes the null values, but Count(colName) doesn't. rowCount -= getNumNulls(filePath, blockMetaData, index) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala new file mode 100644 index 0000000000000..9fc220f440bc1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.connector.expressions.{Expression, NamedReference} + +object V2ColumnUtils { + def extractV2Column(expr: Expression): Option[String] = expr match { + case r: NamedReference if r. fieldNames.length == 1 => Some(r.fieldNames.head) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 23cdf25a86652..fe718668fa2de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, Timesta import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.index.TableIndex -import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.{FieldReference, GeneralSQLExpression, NamedReference} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} @@ -203,28 +203,55 @@ abstract class JdbcDialect extends Serializable with Logging{ def compileAggregate(aggFunction: AggregateFunc): Option[String] = { aggFunction match { case min: Min => - if (min.column.fieldNames.length != 1) return None - Some(s"MIN(${quoteIdentifier(min.column.fieldNames.head)})") + val sql = min.column match { + case field: FieldReference => + if (field.fieldNames.length != 1) return None + quoteIdentifier(field.fieldNames.head) + case expr: GeneralSQLExpression => + expr.sql() + } + Some(s"MIN($sql)") case max: Max => - if (max.column.fieldNames.length != 1) return None - Some(s"MAX(${quoteIdentifier(max.column.fieldNames.head)})") + val sql = max.column match { + case field: FieldReference => + if (field.fieldNames.length != 1) return None + quoteIdentifier(field.fieldNames.head) + case expr: GeneralSQLExpression => + expr.sql() + } + Some(s"MAX($sql)") case count: Count => - if (count.column.fieldNames.length != 1) return None + val sql = count.column match { + case field: FieldReference => + if (field.fieldNames.length != 1) return None + quoteIdentifier(field.fieldNames.head) + case expr: GeneralSQLExpression => + expr.sql() + } val distinct = if (count.isDistinct) "DISTINCT " else "" - val column = quoteIdentifier(count.column.fieldNames.head) - Some(s"COUNT($distinct$column)") + Some(s"COUNT($distinct$sql)") case sum: Sum => - if (sum.column.fieldNames.length != 1) return None + val sql = sum.column match { 
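// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the `if ...isDefined` guards
// added to OrcUtils/ParquetUtils only fire when the aggregate's column is a
// single-part named reference; nested or computed inputs fall through, so the
// aggregate is evaluated by Spark instead of from file footer statistics.
// The stand-in types below are assumptions made only for this example.
// ---------------------------------------------------------------------------
object ExtractV2ColumnSketch {
  sealed trait Expr
  final case class Named(fieldNames: Seq[String]) extends Expr   // stand-in for NamedReference
  final case class OtherExpr(sql: String) extends Expr           // e.g. a CASE WHEN pushed as SQL

  // Same shape as V2ColumnUtils.extractV2Column: only top-level columns qualify.
  def extractV2Column(e: Expr): Option[String] = e match {
    case Named(names) if names.length == 1 => Some(names.head)
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    println(extractV2Column(Named(Seq("id"))))            // Some(id) -> footer stats usable
    println(extractV2Column(Named(Seq("a", "b"))))        // None     -> nested column, skip
    println(extractV2Column(OtherExpr("CASE WHEN ...")))  // None     -> fall back to a normal scan
  }
}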
+ case field: FieldReference => + if (field.fieldNames.length != 1) return None + quoteIdentifier(field.fieldNames.head) + case expr: GeneralSQLExpression => + expr.sql() + } val distinct = if (sum.isDistinct) "DISTINCT " else "" - val column = quoteIdentifier(sum.column.fieldNames.head) - Some(s"SUM($distinct$column)") + Some(s"SUM($distinct$sql)") case _: CountStar => Some("COUNT(*)") case avg: Avg => - if (avg.column.fieldNames.length != 1) return None + val sql = avg.column match { + case field: FieldReference => + if (field.fieldNames.length != 1) return None + quoteIdentifier(field.fieldNames.head) + case expr: GeneralSQLExpression => + expr.sql() + } val distinct = if (avg.isDistinct) "DISTINCT " else "" - val column = quoteIdentifier(avg.column.fieldNames.head) - Some(s"AVG($distinct$column)") + Some(s"AVG($distinct$sql)") case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index a3d01e483209a..f8cb77757682a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -305,6 +305,34 @@ trait FileSourceAggregatePushDownSuite } } + test("aggregate not push down - MIN/MAX/COUNT with CASE WHEN") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + val selectAgg = sql( + """ + |SELECT + | min(CASE WHEN _1 < 0 THEN 0 ELSE _1 END), + | min(CASE WHEN _3 > 5 THEN 1 ELSE 0 END), + | max(CASE WHEN _1 < 0 THEN 0 ELSE _1 END), + | max(CASE WHEN NOT(_3 > 5) THEN 1 ELSE 0 END), + | count(CASE WHEN _1 < 0 AND _2 IS NOT NULL THEN 0 ELSE _1 END), + | count(CASE WHEN _3 != 5 OR _2 IS NULL THEN 1 ELSE 0 END) + |FROM t + """.stripMargin) + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + + checkAnswer(selectAgg, Seq(Row(0, 0, 9, 1, 6, 6))) + } + } + } + private def testPushDownForAllDataTypes( inputRows: Seq[Row], expectedMinWithAllTypes: Seq[Row], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index a65d689385e63..933c464da678f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -791,17 +791,72 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df2, Seq(Row(53000.00))) } - test("scan with aggregate push-down: SUM(CASE WHEN) with group by") { - val df = - sql("SELECT SUM(CASE WHEN SALARY > 0 THEN 1 ELSE 0 END) FROM h2.test.employee GROUP BY DEPT") - checkAggregateRemoved(df, false) + test("scan with aggregate push-down: aggregate with partially pushed down filters" + + "will NOT push down") { + val df = spark.table("h2.test.employee") + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val query = df.select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter("SALARY > 100") + .filter(name($"shortName")) + 
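// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): every branch of the rewritten
// compileAggregate repeats the same "FieldReference vs GeneralSQLExpression"
// rendering step. A hedged sketch of that shared step with simplified stand-in
// types; the helper name and types are assumptions, not the dialect's real API.
// ---------------------------------------------------------------------------
object CompileAggregateSketch {
  sealed trait V2Expr
  final case class FieldRef(fieldNames: Seq[String]) extends V2Expr
  final case class GeneralSQL(sql: String) extends V2Expr

  private def quoteIdentifier(name: String): String = "\"" + name + "\""

  // Renders the aggregate input: quoted column name, raw SQL text, or None when
  // the reference is nested and therefore not compilable by this dialect.
  def renderInput(e: V2Expr): Option[String] = e match {
    case FieldRef(names) if names.length == 1 => Some(quoteIdentifier(names.head))
    case FieldRef(_)                          => None
    case GeneralSQL(sql)                      => Some(sql)
  }

  def compileMin(column: V2Expr): Option[String] = renderInput(column).map(s => s"MIN($s)")

  def main(args: Array[String]): Unit = {
    println(compileMin(FieldRef(Seq("SALARY"))))                              // Some(MIN("SALARY"))
    println(compileMin(GeneralSQL("CASE WHEN SALARY > 0 THEN 1 ELSE 0 END"))) // Some(MIN(CASE WHEN ...))
    println(compileMin(FieldRef(Seq("a", "b"))))                              // None -> not pushed
  }
}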
.agg(sum($"SALARY").as("sum_salary")) + checkAggregateRemoved(query, false) + query.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.aggregation.isEmpty) + } + } + checkAnswer(query, Seq(Row(29000.0))) + } + + test("scan with aggregate push-down: aggregate function with CASE WHEN") { + val df = sql( + """ + |SELECT + | COUNT(CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY > 8000 AND SALARY <= 13000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY > 11000 OR SALARY < 10000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY >= 12000 OR SALARY < 9000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY >= 12000 OR NOT(SALARY >= 9000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 8000) AND SALARY >= 8000 THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 8000) OR SALARY > 8000 THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 8000) AND NOT(SALARY < 8000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY != 0) OR NOT(SALARY < 8000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 8000 AND SALARY > 8000) THEN 0 ELSE SALARY END), + | MIN(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NULL) THEN SALARY ELSE 0 END), + | SUM(CASE WHEN NOT(SALARY > 8000 AND SALARY IS NOT NULL) THEN SALARY ELSE 0 END), + | SUM(CASE WHEN SALARY > 10000 THEN 2 WHEN SALARY > 8000 THEN 1 END), + | AVG(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NOT NULL) THEN SALARY ELSE 0 END) + |FROM h2.test.employee GROUP BY DEPT + """.stripMargin) + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedFilters: [], " + "PushedAggregates: [COUNT(CASE WHEN ((SALARY) > (8000.00)) AND ((SALARY) < (10000.00))" + + " THEN SALARY ELSE 0.00 END), C..., " + + "PushedFilters: [], " + + "PushedGroupByColumns: [DEPT]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } - checkAnswer(df, Seq(Row(1), Row(2), Row(2))) + checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 2, 0d), + Row(2, 2, 2, 2, 2, 0d, 10000d, 0d, 10000d, 10000d, 0d, 0d, 2, 0d), + Row(2, 2, 2, 2, 2, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 3, 0d))) + } + + test("scan with aggregate push-down: aggregate function with UDF") { + val df = spark.table("h2.test.employee") + val decrease = udf { (x: Double, y: Double) => x - y } + val query = df.select(sum(decrease($"SALARY", $"BONUS")).as("value")) + checkAggregateRemoved(query, false) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: []" + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(47100.0))) } test("scan with aggregate push-down: partition columns with multi group by columns") { From f9b54fbd83f8966de6bddbe786b47bd99d26702b Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 18 Feb 2022 22:22:04 +0800 Subject: [PATCH 29/53] [SPARK-37867][SQL][FOLLOWUP] Compile aggregate functions for build-in DB2 dialect ### What changes were proposed in this pull request? This PR follows up https://github.com/apache/spark/pull/35166. The previously referenced DB2 documentation is incorrect, resulting in the lack of compile that supports some aggregate functions. 
The correct documentation is https://www.ibm.com/docs/en/db2/11.5?topic=af-regression-functions-regr-avgx-regr-avgy-regr-count ### Why are the changes needed? Make build-in DB2 dialect support complete aggregate push-down more aggregate functions. ### Does this PR introduce _any_ user-facing change? 'Yes'. Users could use complete aggregate push-down with build-in DB2 dialect. ### How was this patch tested? New tests. Closes #35520 from beliefer/SPARK-37867_followup. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../sql/jdbc/v2/DB2IntegrationSuite.scala | 9 +++ .../jdbc/v2/MsSqlServerIntegrationSuite.scala | 4 ++ .../jdbc/v2/PostgresIntegrationSuite.scala | 7 +++ .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 63 ++++++++++--------- .../apache/spark/sql/jdbc/DB2Dialect.scala | 19 ++++++ .../apache/spark/sql/jdbc/DerbyDialect.scala | 23 +++---- .../spark/sql/jdbc/MsSqlServerDialect.scala | 3 + .../apache/spark/sql/jdbc/MySQLDialect.scala | 21 +++---- .../apache/spark/sql/jdbc/OracleDialect.scala | 38 +++++------ .../spark/sql/jdbc/PostgresDialect.scala | 1 + .../spark/sql/jdbc/TeradataDialect.scala | 18 +++--- 11 files changed, 123 insertions(+), 83 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index d0479e9032e06..35711e57d0b72 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -97,4 +97,13 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) testVarPop() + testVarPop(true) + testVarSamp() + testVarSamp(true) + testStddevPop() + testStddevPop(true) + testStddevSamp() + testStddevSamp(true) + testCovarPop() + testCovarSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 536eb465ceb11..4df5f4525a0fa 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -97,7 +97,11 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD } testVarPop() + testVarPop(true) testVarSamp() + testVarSamp(true) testStddevPop() + testStddevPop(true) testStddevSamp() + testStddevSamp(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index b3004e1c21c89..d76e13c1cd421 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -91,10 +91,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT override def indexOptions: String = "FILLFACTOR=70" testVarPop() + testVarPop(true) testVarSamp() + testVarSamp(true) testStddevPop() 
+ testStddevPop(true) testStddevSamp() + testStddevSamp(true) testCovarPop() + testCovarPop(true) testCovarSamp() + testCovarSamp(true) testCorr() + testCorr(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 667579b20eaf7..7cab8cd77df66 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -386,10 +386,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu protected def caseConvert(tableName: String): String = tableName - protected def testVarPop(): Unit = { - test(s"scan with aggregate push-down: VAR_POP") { - val df = sql(s"SELECT VAR_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + protected def testVarPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: VAR_POP with distinct: $isDistinct") { + val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "VAR_POP") @@ -401,11 +402,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testVarSamp(): Unit = { - test(s"scan with aggregate push-down: VAR_SAMP") { + protected def testVarSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: VAR_SAMP with distinct: $isDistinct") { val df = sql( - s"SELECT VAR_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "VAR_SAMP") @@ -417,11 +419,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testStddevPop(): Unit = { - test("scan with aggregate push-down: STDDEV_POP") { + protected def testStddevPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: STDDEV_POP with distinct: $isDistinct") { val df = sql( - s"SELECT STDDEV_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "STDDEV_POP") @@ -433,11 +436,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testStddevSamp(): Unit = { - test("scan with aggregate push-down: STDDEV_SAMP") { + protected def testStddevSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: STDDEV_SAMP with distinct: $isDistinct") { val df = sql( - s"SELECT STDDEV_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "STDDEV_SAMP") @@ -449,11 +453,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testCovarPop(): Unit = { - test("scan with aggregate push-down: COVAR_POP") { + protected def testCovarPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: COVAR_POP with distinct: $isDistinct") { val df = sql( - s"SELECT COVAR_POP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "COVAR_POP") @@ -465,11 +470,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testCovarSamp(): Unit = { - test("scan with aggregate push-down: COVAR_SAMP") { + protected def testCovarSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: COVAR_SAMP with distinct: $isDistinct") { val df = sql( - s"SELECT COVAR_SAMP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "COVAR_SAMP") @@ -481,11 +487,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testCorr(): Unit = { - test("scan with aggregate push-down: CORR") { + protected def testCorr(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: CORR with distinct: $isDistinct") { val df = sql( - s"SELECT CORR(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "CORR") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index ffda7545c6e9f..dd68953badf7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -30,6 +30,7 @@ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2") + // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { @@ -37,6 +38,24 @@ private object DB2Dialect extends JdbcDialect { assert(f.inputs().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" Some(s"VARIANCE($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARIANCE_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"COVARIANCE(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"COVARIANCE_SAMP(${f.inputs().head}, ${f.inputs().last})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index e87d4d08ae031..bf838b8ed66eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -30,25 +30,22 @@ private object DerbyDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby") + // See https://db.apache.org/derby/docs/10.15/ref/index.html override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + Some(s"VAR_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + Some(s"VAR_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && 
f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + Some(s"STDDEV_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + Some(s"STDDEV_SAMP(${f.inputs().head})") case _ => None } ) @@ -72,7 +69,7 @@ private object DerbyDialect extends JdbcDialect { override def isCascadingTruncateTable(): Option[Boolean] = Some(false) - // See https://db.apache.org/derby/docs/10.5/ref/rrefsqljrenametablestatement.html + // See https://db.apache.org/derby/docs/10.15/ref/rrefsqljrenametablestatement.html override def renameTable(oldTable: String, newTable: String): String = { s"RENAME TABLE $oldTable TO $newTable" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 3d8a48a66ea8f..841f1c87319b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -40,6 +40,9 @@ private object MsSqlServerDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver") + // scalastyle:off line.size.limit + // See https://docs.microsoft.com/en-us/sql/t-sql/functions/aggregate-functions-transact-sql?view=sql-server-ver15 + // scalastyle:on line.size.limit override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index c32499b5f32e1..b1093a4f2f7c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -38,25 +38,22 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url : String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") + // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + Some(s"VAR_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + Some(s"VAR_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.inputs().head})") - case f: 
GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + Some(s"STDDEV_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + Some(s"STDDEV_SAMP(${f.inputs().head})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 4fe7d93142c1e..71db7e9285f5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -34,37 +34,33 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle") + // scalastyle:off line.size.limit + // https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848 + // scalastyle:on line.size.limit override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + Some(s"VAR_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + Some(s"VAR_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + Some(s"STDDEV_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + Some(s"STDDEV_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" => + Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - 
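// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): dialects whose VAR_*/STDDEV_*/
// COVAR_*/CORR functions accept no DISTINCT qualifier now match only when
// `isDistinct` is false, so a DISTINCT aggregate simply stays un-pushed and
// Spark computes it. Simplified stand-in types; names are assumptions.
// ---------------------------------------------------------------------------
object DistinctGuardSketch {
  final case class GeneralAgg(name: String, isDistinct: Boolean, inputs: Seq[String])

  // Guarded translation: DISTINCT forms of these functions are not compiled.
  def compile(f: GeneralAgg): Option[String] = f match {
    case GeneralAgg("VAR_POP", false, Seq(in))     => Some(s"VAR_POP($in)")
    case GeneralAgg("COVAR_POP", false, Seq(l, r)) => Some(s"COVAR_POP($l, $r)")
    case _                                         => None   // e.g. VAR_POP(DISTINCT x)
  }

  def main(args: Array[String]): Unit = {
    println(compile(GeneralAgg("VAR_POP", isDistinct = false, Seq("bonus"))))  // pushed
    println(compile(GeneralAgg("VAR_POP", isDistinct = true, Seq("bonus"))))   // None -> not pushed
  }
}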
Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"CORR(${f.inputs().head}, ${f.inputs().last})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 46e79404f3e54..e2023d110ae4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -36,6 +36,7 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") + // See https://www.postgresql.org/docs/8.4/functions-aggregate.html override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 6344667b3180e..13e16d24d048d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -28,6 +28,9 @@ private case object TeradataDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata") + // scalastyle:off line.size.limit + // See https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/Aggregate-Functions + // scalastyle:on line.size.limit override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { @@ -47,18 +50,15 @@ private case object TeradataDialect extends JdbcDialect { assert(f.inputs().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" Some(s"STDDEV_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" => + Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"CORR(${f.inputs().head}, ${f.inputs().last})") case _ => None } ) From 16cb31977f89676b060137d66a2f417d8e292f6b Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 26 Aug 2021 11:26:39 +0800 Subject: [PATCH 30/53] [SPARK-36568][SQL] Better FileScan statistics estimation ### What changes were proposed in this pull request? This PR modifies `FileScan.estimateStatistics()` to take the read schema into account. ### Why are the changes needed? 
`V2ScanRelationPushDown` can column prune `DataSourceV2ScanRelation`s and change read schema of `Scan` operations. The better statistics returned by `FileScan.estimateStatistics()` can mean better query plans. For example, with this change the broadcast issue in SPARK-36568 can be avoided. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added new UT. Closes #33825 from peter-toth/SPARK-36568-scan-statistics-estimation. Authored-by: Peter Toth Signed-off-by: Wenchen Fan --- .../execution/datasources/v2/FileScan.scala | 7 +++++- .../datasources/v2/text/TextScan.scala | 1 + .../spark/sql/FileBasedDataSourceSuite.scala | 22 +++++++++++++++++++ .../org/apache/spark/sql/FileScanSuite.scala | 2 +- .../apache/spark/sql/test/SQLTestUtils.scala | 5 ++++- 5 files changed, 34 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 0212cdf63fcf9..8b0328cabc5a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -49,6 +49,8 @@ trait FileScan extends Scan def fileIndex: PartitioningAwareFileIndex + def dataSchema: StructType + /** * Returns the required data schema */ @@ -181,7 +183,10 @@ trait FileScan extends Scan new Statistics { override def sizeInBytes(): OptionalLong = { val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor - val size = (compressionFactor * fileIndex.sizeInBytes).toLong + val size = (compressionFactor * fileIndex.sizeInBytes / + (dataSchema.defaultSize + fileIndex.partitionSchema.defaultSize) * + (readDataSchema.defaultSize + readPartitionSchema.defaultSize)).toLong + OptionalLong.of(size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala index 3582978a8c569..c7b0fec34b4e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala @@ -33,6 +33,7 @@ import org.apache.spark.util.SerializableConfiguration case class TextScan( sparkSession: SparkSession, fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 001b6a00af52f..910f159cc49a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -731,6 +731,28 @@ class FileBasedDataSourceSuite extends QueryTest } } + test("SPARK-36568: FileScan statistics estimation takes read schema into account") { + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + withTempDir { dir => + spark.range(1000).map(x => (x / 100, x, x)).toDF("k", "v1", "v2"). 
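// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the new estimateStatistics scales
// raw on-disk size by the fraction of the schema actually read, i.e.
//   size = compressionFactor * bytesOnDisk
//          / (dataSchema.defaultSize + partitionSchema.defaultSize)
//          * (readDataSchema.defaultSize + readPartitionSchema.defaultSize)
// A tiny numeric stand-in of that arithmetic, with assumed example values:
// ---------------------------------------------------------------------------
object FileScanStatsSketch {
  def estimatedSize(
      compressionFactor: Double,
      bytesOnDisk: Long,
      fullSchemaDefaultSize: Int,
      readSchemaDefaultSize: Int): Long =
    (compressionFactor * bytesOnDisk / fullSchemaDefaultSize * readSchemaDefaultSize).toLong

  def main(args: Array[String]): Unit = {
    // 1 GiB of files, three 8-byte columns on disk, only one column read:
    println(estimatedSize(1.0, 1L << 30, fullSchemaDefaultSize = 24, readSchemaDefaultSize = 8))
    // roughly one third of the on-disk size, which is what lets the pruned scan
    // fall under the broadcast threshold in the SPARK-36568 scenario.
  }
}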
+ write.partitionBy("k").mode(SaveMode.Overwrite).orc(dir.toString) + val dfAll = spark.read.orc(dir.toString) + val dfK = dfAll.select("k") + val dfV1 = dfAll.select("v1") + val dfV2 = dfAll.select("v2") + val dfV1V2 = dfAll.select("v1", "v2") + + def sizeInBytes(df: DataFrame): BigInt = df.queryExecution.optimizedPlan.stats.sizeInBytes + + assert(sizeInBytes(dfAll) === BigInt(getLocalDirSize(dir))) + assert(sizeInBytes(dfK) < sizeInBytes(dfAll)) + assert(sizeInBytes(dfV1) < sizeInBytes(dfAll)) + assert(sizeInBytes(dfV2) === sizeInBytes(dfV1)) + assert(sizeInBytes(dfV1V2) < sizeInBytes(dfAll)) + } + } + } + test("File source v2: support partition pruning") { withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { allFileBasedDataSources.foreach { format => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index e213b32fdd243..14b59ba23d09f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -367,7 +367,7 @@ class FileScanSuite extends FileScanSuiteBase { (s, fi, ds, rds, rps, f, o, pf, df) => JsonScan(s, fi, ds, rds, rps, o, f, pf, df), Seq.empty), ("TextScan", - (s, fi, _, rds, rps, _, o, pf, df) => TextScan(s, fi, rds, rps, o, pf, df), + (s, fi, ds, rds, rps, _, o, pf, df) => TextScan(s, fi, ds, rds, rps, o, pf, df), Seq("dataSchema", "pushedFilters"))) run(scanBuilders) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 0e62be40607a1..ba0b599f2245d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -22,6 +22,7 @@ import java.net.URI import java.nio.file.Files import java.util.{Locale, UUID} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.util.control.NonFatal @@ -459,7 +460,9 @@ private[sql] trait SQLTestUtilsBase */ def getLocalDirSize(file: File): Long = { assert(file.isDirectory) - file.listFiles.filter(f => DataSourceUtils.isDataFile(f.getName)).map(_.length).sum + Files.walk(file.toPath).iterator().asScala + .filter(p => Files.isRegularFile(p) && DataSourceUtils.isDataFile(p.getFileName.toString)) + .map(_.toFile.length).sum } } From 659f15f02cdc72b5522d7716bb28ae4d81cb5e1b Mon Sep 17 00:00:00 2001 From: dch nguyen Date: Fri, 21 Jan 2022 15:23:05 +0800 Subject: [PATCH 31/53] [SPARK-37929][SQL] Support cascade mode for `dropNamespace` API ### What changes were proposed in this pull request? This PR adds a new API `dropNamespace(String[] ns, boolean cascade)` to replace the existing one: Add a boolean parameter `cascade` that supports deleting all the Namespaces and Tables under the namespace. Also include changing the implementations and tests that are relevant to this API. ### Why are the changes needed? According to [#cmt](https://github.com/apache/spark/pull/35202#discussion_r784463563), the current `dropNamespace` API doesn't support cascade mode. So this PR replaces that to support cascading. If cascade is set True, delete all namespaces and tables under the namespace. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test. Closes #35246 from dchvn/change_dropnamespace_api. 
Authored-by: dch nguyen Signed-off-by: Wenchen Fan --- .../catalog/DelegatingCatalogExtension.java | 6 ++- .../connector/catalog/SupportsNamespaces.java | 10 ++++- .../catalyst/analysis/NonEmptyException.scala | 36 ++++++++++++++++ .../sql/errors/QueryCompilationErrors.scala | 5 +++ .../sql/connector/catalog/CatalogSuite.scala | 6 +-- .../catalog/InMemoryTableCatalog.scala | 14 +++++-- .../datasources/v2/DropNamespaceExec.scala | 19 ++++----- .../datasources/v2/V2SessionCatalog.scala | 9 ++-- .../v2/jdbc/JDBCTableCatalog.scala | 4 +- .../v2/V2SessionCatalogSuite.scala | 42 +++++++++---------- 10 files changed, 101 insertions(+), 50 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java index 34f07b12b3666..5edf51969d646 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java @@ -147,8 +147,10 @@ public void alterNamespace( } @Override - public boolean dropNamespace(String[] namespace) throws NoSuchNamespaceException { - return asNamespaceCatalog().dropNamespace(namespace); + public boolean dropNamespace( + String[] namespace, + boolean cascade) throws NoSuchNamespaceException, NonEmptyNamespaceException { + return asNamespaceCatalog().dropNamespace(namespace, cascade); } private TableCatalog asTableCatalog() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java index f70746b612e92..c1a4960068d24 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java @@ -20,6 +20,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException; import java.util.Map; @@ -136,15 +137,20 @@ void alterNamespace( NamespaceChange... changes) throws NoSuchNamespaceException; /** - * Drop a namespace from the catalog, recursively dropping all objects within the namespace. + * Drop a namespace from the catalog with cascade mode, recursively dropping all objects + * within the namespace if cascade is true. *

    * If the catalog implementation does not support this operation, it may throw * {@link UnsupportedOperationException}. * * @param namespace a multi-part namespace + * @param cascade When true, deletes all objects under the namespace * @return true if the namespace was dropped * @throws NoSuchNamespaceException If the namespace does not exist (optional) + * @throws NonEmptyNamespaceException If the namespace is non-empty and cascade is false * @throws UnsupportedOperationException If drop is not a supported operation */ - boolean dropNamespace(String[] namespace) throws NoSuchNamespaceException; + boolean dropNamespace( + String[] namespace, + boolean cascade) throws NoSuchNamespaceException, NonEmptyNamespaceException; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala new file mode 100644 index 0000000000000..f3ff28f74fcc3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + +/** + * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception + * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. + */ +case class NonEmptyNamespaceException( + override val message: String, + override val cause: Option[Throwable] = None) + extends AnalysisException(message, cause = cause) { + + def this(namespace: Array[String]) = { + this(s"Namespace '${namespace.quoted}' is non empty.") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index ef262a88b7ecb..88c00c02597e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -555,6 +555,11 @@ object QueryCompilationErrors { new AnalysisException(s"Database $db is not empty. One or more $details exist.") } + def cannotDropNonemptyNamespaceError(namespace: Seq[String]): Throwable = { + new AnalysisException(s"Cannot drop a non-empty namespace: ${namespace.quoted}. " + + "Use CASCADE option to drop a non-empty namespace.") + } + def invalidNameForTableOrDatabaseError(name: String): Throwable = { new AnalysisException(s"`$name` is not a valid name for tables/databases. 
" + "Valid names only contain alphabet characters, numbers and _.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala index 0cca1cc9bebf2..d00bc31e07f19 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala @@ -820,7 +820,7 @@ class CatalogSuite extends SparkFunSuite { assert(catalog.namespaceExists(testNs) === false) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === false) } @@ -833,7 +833,7 @@ class CatalogSuite extends SparkFunSuite { assert(catalog.namespaceExists(testNs) === true) assert(catalog.loadNamespaceMetadata(testNs).asScala === Map("property" -> "value")) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === true) assert(catalog.namespaceExists(testNs) === false) @@ -845,7 +845,7 @@ class CatalogSuite extends SparkFunSuite { catalog.createNamespace(testNs, Map("property" -> "value").asJava) catalog.createTable(testIdent, schema, Array.empty, emptyProps) - assert(catalog.dropNamespace(testNs)) + assert(catalog.dropNamespace(testNs, cascade = true)) assert(!catalog.namespaceExists(testNs)) intercept[NoSuchNamespaceException](catalog.listTables(testNs)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index 0c403baca2113..41063a41b9719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NonEmptyNamespaceException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.types.StructType @@ -193,10 +193,16 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp namespaces.put(namespace.toList, CatalogV2Util.applyNamespaceChanges(metadata, changes)) } - override def dropNamespace(namespace: Array[String]): Boolean = { - listNamespaces(namespace).foreach(dropNamespace) + override def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = { try { - listTables(namespace).foreach(dropTable) + if (!cascade) { + if (listTables(namespace).nonEmpty || listNamespaces(namespace).nonEmpty) { + throw new NonEmptyNamespaceException(namespace) + } + } else { + listNamespaces(namespace).foreach(namespace => dropNamespace(namespace, cascade)) + listTables(namespace).foreach(dropTable) + } } catch { case _: NoSuchNamespaceException => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala index dbd5cbd874945..5d302055e7d91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.catalog.CatalogPlugin -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.errors.QueryCompilationErrors /** * Physical plan node for dropping a namespace. @@ -37,17 +38,11 @@ case class DropNamespaceExec( val nsCatalog = catalog.asNamespaceCatalog val ns = namespace.toArray if (nsCatalog.namespaceExists(ns)) { - // The default behavior of `SupportsNamespace.dropNamespace()` is cascading, - // so make sure the namespace to drop is empty. - if (!cascade) { - if (catalog.asTableCatalog.listTables(ns).nonEmpty - || nsCatalog.listNamespaces(ns).nonEmpty) { - throw QueryExecutionErrors.cannotDropNonemptyNamespaceError(namespace) - } - } - - if (!nsCatalog.dropNamespace(ns)) { - throw QueryExecutionErrors.cannotDropNonemptyNamespaceError(namespace) + try { + nsCatalog.dropNamespace(ns, cascade) + } catch { + case _: NonEmptyNamespaceException => + throw QueryCompilationErrors.cannotDropNonemptyNamespaceError(namespace) } } else if (!ifExists) { throw QueryCompilationErrors.noSuchNamespaceError(ns) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index d4a981d2205da..fe91cc486967b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -261,12 +261,11 @@ class V2SessionCatalog(catalog: SessionCatalog) } } - override def dropNamespace(namespace: Array[String]): Boolean = namespace match { + override def dropNamespace( + namespace: Array[String], + cascade: Boolean): Boolean = namespace match { case Array(db) if catalog.databaseExists(db) => - if (catalog.listTables(db).nonEmpty) { - throw QueryExecutionErrors.namespaceNotEmptyError(namespace) - } - catalog.dropDatabase(db, ignoreIfNotExists = false, cascade = false) + catalog.dropDatabase(db, ignoreIfNotExists = false, cascade) true case Array(_) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala index f311cf63d1419..03200d5a6f371 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala @@ -272,7 +272,9 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } } - override def dropNamespace(namespace: Array[String]): Boolean = namespace match { + override def dropNamespace( + namespace: Array[String], + cascade: Boolean): Boolean = namespace match { case Array(db) if namespaceExists(namespace) => JdbcUtils.withConnection(options) 
{ conn => JdbcUtils.classifyException(s"Failed drop name space: $db", dialect) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index 1a4f08418f8d3..1a52dc4da009f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -67,10 +67,10 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { override protected def afterAll(): Unit = { val catalog = newCatalog() - catalog.dropNamespace(Array("db")) - catalog.dropNamespace(Array("db2")) - catalog.dropNamespace(Array("ns")) - catalog.dropNamespace(Array("ns2")) + catalog.dropNamespace(Array("db"), cascade = true) + catalog.dropNamespace(Array("db2"), cascade = true) + catalog.dropNamespace(Array("ns"), cascade = true) + catalog.dropNamespace(Array("ns2"), cascade = true) super.afterAll() } @@ -806,7 +806,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.listNamespaces(Array()) === Array(testNs, defaultNs)) assert(catalog.listNamespaces(testNs) === Array()) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("listNamespaces: fail if missing namespace") { @@ -844,7 +844,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) checkMetadata(metadata.asScala, Map("property" -> "value")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("loadNamespaceMetadata: empty metadata") { @@ -859,7 +859,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) checkMetadata(metadata.asScala, emptyProps.asScala) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: basic behavior") { @@ -879,7 +879,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(metadata, Map("property" -> "value")) assert(expectedPath === metadata("location")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: initialize location") { @@ -895,7 +895,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(metadata, Map.empty) assert(expectedPath === metadata("location")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: relative location") { @@ -912,7 +912,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(metadata, Map.empty) assert(expectedPath === metadata("location")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: fail if namespace already exists") { @@ -928,7 +928,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) checkMetadata(catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: fail nested namespace") { @@ -943,7 +943,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(exc.getMessage.contains("Invalid namespace name: db.nested")) - 
catalog.dropNamespace(Array("db")) + catalog.dropNamespace(Array("db"), cascade = false) } test("createTable: fail if namespace does not exist") { @@ -964,7 +964,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === false) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === false) } @@ -976,7 +976,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === true) assert(catalog.namespaceExists(testNs) === false) @@ -988,8 +988,8 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { catalog.createNamespace(testNs, Map("property" -> "value").asJava) catalog.createTable(testIdent, schema, Array.empty, emptyProps) - val exc = intercept[IllegalStateException] { - catalog.dropNamespace(testNs) + val exc = intercept[AnalysisException] { + catalog.dropNamespace(testNs, cascade = false) } assert(exc.getMessage.contains(testNs.quoted)) @@ -997,7 +997,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value")) catalog.dropTable(testIdent) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: basic behavior") { @@ -1022,7 +1022,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: update namespace location") { @@ -1045,7 +1045,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { catalog.alterNamespace(testNs, NamespaceChange.setProperty("location", "relativeP")) assert(newRelativePath === spark.catalog.getDatabase(testNs(0)).locationUri) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: update namespace comment") { @@ -1060,7 +1060,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(newComment === spark.catalog.getDatabase(testNs(0)).description) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: fail if namespace doesn't exist") { @@ -1087,6 +1087,6 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(exc.getMessage.contains(s"Cannot remove reserved property: $p")) } - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } } From 7e5c9ba8a4e9424349ee57d42967dd3783767960 Mon Sep 17 00:00:00 2001 From: chenzhx Date: Tue, 22 Feb 2022 20:33:10 +0800 Subject: [PATCH 32/53] code format --- .../connector/catalog/DelegatingCatalogExtension.java | 5 +---- .../sql/connector/catalog/index/SupportsIndex.java | 11 +++++++---- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 7 +++---- .../sql/execution/datasources/v2/jdbc/JDBCTable.scala | 9 ++++----- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 6 ++---- .../org/apache/spark/sql/jdbc/MySQLDialect.scala | 4 ++-- .../scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 7 +++++++ 7 files changed, 26 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java index 5edf51969d646..66e8a431458f9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java @@ -20,10 +20,7 @@ import java.util.Map; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; -import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; -import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.*; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java index 4181cf5f25118..1419e975f5695 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java @@ -34,21 +34,24 @@ @Evolving public interface SupportsIndex extends Table { + /** + * A reserved property to specify the index type. + */ + String PROP_TYPE = "type"; + /** * Creates an index. * * @param indexName the name of the index to be created - * @param indexType the IndexType of the index to be created * @param columns the columns on which index to be created * @param columnsProperties the properties of the columns on which index to be created * @param properties the properties of the index to be created * @throws IndexAlreadyExistsException If the index already exists. 
*/ void createIndex(String indexName, - String indexType, NamedReference[] columns, - Map[] columnsProperties, - Properties properties) + Map> columnsProperties, + Map properties) throws IndexAlreadyExistsException; /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index b554814f1e193..da627a16099c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -1035,15 +1035,14 @@ object JdbcUtils extends Logging { def createIndex( conn: Connection, indexName: String, - indexType: String, tableName: String, columns: Array[NamedReference], - columnsProperties: Array[util.Map[NamedReference, util.Properties]], - properties: util.Properties, + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String], options: JDBCOptions): Unit = { val dialect = JdbcDialects.get(options.url) executeStatement(conn, options, - dialect.createIndex(indexName, indexType, tableName, columns, columnsProperties, properties)) + dialect.createIndex(indexName, tableName, columns, columnsProperties, properties)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala index ba56643f4d980..793b72727b9ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala @@ -54,15 +54,14 @@ case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOpt override def createIndex( indexName: String, - indexType: String, columns: Array[NamedReference], - columnsProperties: Array[util.Map[NamedReference, util.Properties]], - properties: util.Properties): Unit = { + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String]): Unit = { JdbcUtils.withConnection(jdbcOptions) { conn => - JdbcUtils.classifyException(s"Failed to create index: $indexName in $name", + JdbcUtils.classifyException(s"Failed to create index $indexName in $name", JdbcDialects.get(jdbcOptions.url)) { JdbcUtils.createIndex( - conn, indexName, indexType, name, columns, columnsProperties, properties, jdbcOptions) + conn, indexName, name, columns, columnsProperties, properties, jdbcOptions) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index fe718668fa2de..2d10bbf5de537 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -405,7 +405,6 @@ abstract class JdbcDialect extends Serializable with Logging{ * Build a create index SQL statement. 
* * @param indexName the name of the index to be created - * @param indexType the type of the index to be created * @param tableName the table on which index to be created * @param columns the columns on which index to be created * @param columnsProperties the properties of the columns on which index to be created @@ -414,11 +413,10 @@ abstract class JdbcDialect extends Serializable with Logging{ */ def createIndex( indexName: String, - indexType: String, tableName: String, columns: Array[NamedReference], - columnsProperties: Array[util.Map[NamedReference, util.Properties]], - properties: util.Properties): String = { + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String]): String = { throw new UnsupportedOperationException("createIndex is not supported") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index b1093a4f2f7c6..d73721de962d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -166,8 +166,8 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { indexName: String, tableName: String, columns: Array[NamedReference], - columnsProperties: Array[util.Map[NamedReference, util.Properties]], - properties: util.Properties): String = { + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String]): String = { val columnList = columns.map(col => quoteIdentifier(col.fieldNames.head)) val (indexType, indexPropertyList) = JdbcUtils.processIndexProperties(properties, "mysql") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 933c464da678f..c6d0a87787639 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -418,6 +418,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) } + private def checkFiltersRemoved(df: DataFrame): Unit = { + val filters = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + assert(filters.isEmpty) + } + test("scan with aggregate push-down: MAX AVG with filter without group by") { val df = sql("select MAX(ID), AVG(ID) FROM h2.test.people where id > 0") val filters = df.queryExecution.optimizedPlan.collect { From 04fef08e0b36f2faa122a07900354401fe92f659 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 4 Mar 2022 21:23:45 +0800 Subject: [PATCH 33/53] [SPARK-38196][SQL] Refactor framework so as JDBC dialect could compile expression by self way ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/35248 provides a new framework to represent catalyst expressions in DS V2 APIs. Because the framework translate all catalyst expressions to a unified SQL string and cannot keep compatibility between different JDBC database, the framework works not good. This PR reactor the framework so as JDBC dialect could compile expression by self way. First, The framework translate catalyst expressions to DS V2 expression. Second, The JDBC dialect could compile DS V2 expression to different SQL syntax. 
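For illustration only (not part of this patch), the intended two-step flow might be exercised like this, assuming an already-resolved catalyst predicate `catalystFilter` and some concrete `dialect: JdbcDialect`:

```scala
import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder

// Step 1: translate the catalyst Expression into a DS V2 Expression.
// A result of None means the expression cannot be represented and is not pushed down.
val v2Expr = new V2ExpressionBuilder(catalystFilter).build()

// Step 2: the JDBC dialect compiles the DS V2 Expression into its own SQL syntax.
// A dialect that cannot handle the expression returns None as well.
val compiledSql: Option[String] = v2Expr.flatMap(dialect.compileExpression)
```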
The java doc looks show below: ![image](https://user-images.githubusercontent.com/8486025/156579584-f56cafb5-641f-4c5b-a06e-38f4369051c3.png) ### Why are the changes needed? Make the framework be more common use. ### Does this PR introduce _any_ user-facing change? 'No'. The feature is not released. ### How was this patch tested? Exists tests. Closes #35494 from beliefer/SPARK-37960_followup. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../expressions/GeneralSQLExpression.java | 41 ---- .../expressions/GeneralScalarExpression.java | 203 ++++++++++++++++++ .../util/V2ExpressionSQLBuilder.java | 151 +++++++++++++ .../catalyst/util/ExpressionSQLBuilder.scala | 69 ------ .../catalyst/util/V2ExpressionBuilder.scala | 94 ++++++++ .../datasources/DataSourceStrategy.scala | 12 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 73 +++---- ...SourceV2DataFrameSessionCatalogSuite.scala | 4 +- .../connector/DataSourceV2FunctionSuite.scala | 3 +- .../sql/connector/DataSourceV2Suite.scala | 7 +- .../spark/sql/connector/LocalScanSuite.scala | 9 +- .../connector/SimpleWritableDataSource.scala | 5 +- .../connector/TableCapabilityCheckSuite.scala | 10 +- .../connector/TestV2SessionCatalogBase.scala | 9 +- .../sql/connector/V1ReadFallbackSuite.scala | 10 +- .../sql/connector/V1WriteFallbackSuite.scala | 10 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 31 ++- 17 files changed, 542 insertions(+), 199 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java deleted file mode 100644 index ebeee22a853cf..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions; - -import java.io.Serializable; - -import org.apache.spark.annotation.Evolving; - -/** - * The general SQL string corresponding to expression. 
- * - * @since 3.3.0 - */ -@Evolving -public class GeneralSQLExpression implements Expression, Serializable { - private String sql; - - public GeneralSQLExpression(String sql) { - this.sql = sql; - } - - public String sql() { return sql; } - - @Override - public String toString() { return sql; } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java new file mode 100644 index 0000000000000..b3dd2cbfe3d7d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import java.io.Serializable; +import java.util.Arrays; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; + +// scalastyle:off line.size.limit +/** + * The general representation of SQL scalar expressions, which contains the upper-cased + * expression name and all the children expressions. + *

+ * The currently supported SQL scalar expressions:
+ * <ol>
+ *  <li>Name: <code>IS_NULL</code>; SQL semantic: <code>expr IS NULL</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>IS_NOT_NULL</code>; SQL semantic: <code>expr IS NOT NULL</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>=</code>; SQL semantic: <code>expr1 = expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>!=</code>; SQL semantic: <code>expr1 != expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>&lt;&gt;</code>; SQL semantic: <code>expr1 &lt;&gt; expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>&lt;=&gt;</code>; SQL semantic: <code>expr1 &lt;=&gt; expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>&lt;</code>; SQL semantic: <code>expr1 &lt; expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>&lt;=</code>; SQL semantic: <code>expr1 &lt;= expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>&gt;</code>; SQL semantic: <code>expr1 &gt; expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>&gt;=</code>; SQL semantic: <code>expr1 &gt;= expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>+</code>; SQL semantic: <code>expr1 + expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>-</code>; SQL semantic: <code>expr1 - expr2</code> or <code>- expr</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>*</code>; SQL semantic: <code>expr1 * expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>/</code>; SQL semantic: <code>expr1 / expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>%</code>; SQL semantic: <code>expr1 % expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>&amp;</code>; SQL semantic: <code>expr1 &amp; expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>|</code>; SQL semantic: <code>expr1 | expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>^</code>; SQL semantic: <code>expr1 ^ expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>AND</code>; SQL semantic: <code>expr1 AND expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>OR</code>; SQL semantic: <code>expr1 OR expr2</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>NOT</code>; SQL semantic: <code>NOT expr</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>~</code>; SQL semantic: <code>~ expr</code>; Since version: 3.3.0</li>
+ *  <li>Name: <code>CASE_WHEN</code>; SQL semantic:
+ *   <code>CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END</code>; Since version: 3.3.0</li>
+ * </ol>
    + * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off, + * including: add, subtract, multiply, divide, remainder, pmod. + * + * @since 3.3.0 + */ +// scalastyle:on line.size.limit +@Evolving +public class GeneralScalarExpression implements Expression, Serializable { + private String name; + private Expression[] children; + + public GeneralScalarExpression(String name, Expression[] children) { + this.name = name; + this.children = children; + } + + public String name() { return name; } + public Expression[] children() { return children; } + + @Override + public String toString() { + V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder(); + try { + return builder.build(this); + } catch (Throwable e) { + return name + "(" + + Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java new file mode 100644 index 0000000000000..0af0d88b0f622 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.util; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; +import org.apache.spark.sql.connector.expressions.LiteralValue; + +/** + * The builder to generate SQL from V2 expressions. 
+ */ +public class V2ExpressionSQLBuilder { + public String build(Expression expr) { + if (expr instanceof LiteralValue) { + return visitLiteral((LiteralValue) expr); + } else if (expr instanceof FieldReference) { + return visitFieldReference((FieldReference) expr); + } else if (expr instanceof GeneralScalarExpression) { + GeneralScalarExpression e = (GeneralScalarExpression) expr; + String name = e.name(); + switch (name) { + case "IS_NULL": + return visitIsNull(build(e.children()[0])); + case "IS_NOT_NULL": + return visitIsNotNull(build(e.children()[0])); + case "=": + case "!=": + case "<=>": + case "<": + case "<=": + case ">": + case ">=": + return visitBinaryComparison(name, build(e.children()[0]), build(e.children()[1])); + case "+": + case "*": + case "/": + case "%": + case "&": + case "|": + case "^": + return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1])); + case "-": + if (e.children().length == 1) { + return visitUnaryArithmetic(name, build(e.children()[0])); + } else { + return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1])); + } + case "AND": + return visitAnd(name, build(e.children()[0]), build(e.children()[1])); + case "OR": + return visitOr(name, build(e.children()[0]), build(e.children()[1])); + case "NOT": + return visitNot(build(e.children()[0])); + case "~": + return visitUnaryArithmetic(name, build(e.children()[0])); + case "CASE_WHEN": + List children = new ArrayList<>(); + for (Expression child : e.children()) { + children.add(build(child)); + } + return visitCaseWhen(children.toArray(new String[e.children().length])); + // TODO supports other expressions + default: + return visitUnexpectedExpr(expr); + } + } else { + return visitUnexpectedExpr(expr); + } + } + + protected String visitLiteral(LiteralValue literalValue) { + return literalValue.toString(); + } + + protected String visitFieldReference(FieldReference fieldRef) { + return fieldRef.toString(); + } + + protected String visitIsNull(String v) { + return v + " IS NULL"; + } + + protected String visitIsNotNull(String v) { + return v + " IS NOT NULL"; + } + + protected String visitBinaryComparison(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected String visitBinaryArithmetic(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected String visitAnd(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected String visitOr(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected String visitNot(String v) { + return "NOT (" + v + ")"; + } + + protected String visitUnaryArithmetic(String name, String v) { return name +" (" + v + ")"; } + + protected String visitCaseWhen(String[] children) { + StringBuilder sb = new StringBuilder("CASE"); + for (int i = 0; i < children.length; i += 2) { + String c = children[i]; + int j = i + 1; + if (j < children.length) { + String v = children[j]; + sb.append(" WHEN "); + sb.append(c); + sb.append(" THEN "); + sb.append(v); + } else { + sb.append(" ELSE "); + sb.append(c); + } + } + sb.append(" END"); + return sb.toString(); + } + + protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException { + throw new IllegalArgumentException("Unexpected V2 expression: " + expr); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala deleted file mode 100644 index 6239d0e2e7ae8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryOperator, CaseWhen, EqualTo, Expression, IsNotNull, IsNull, Literal, Not} -import org.apache.spark.sql.connector.expressions.LiteralValue - -/** - * The builder to generate SQL string from catalyst expressions. - */ -class ExpressionSQLBuilder(e: Expression) { - - def build(): Option[String] = generateSQL(e) - - private def generateSQL(expr: Expression): Option[String] = expr match { - case Literal(value, dataType) => Some(LiteralValue(value, dataType).toString) - case a: Attribute => Some(quoteIfNeeded(a.name)) - case IsNull(col) => generateSQL(col).map(c => s"$c IS NULL") - case IsNotNull(col) => generateSQL(col).map(c => s"$c IS NOT NULL") - case b: BinaryOperator => - val l = generateSQL(b.left) - val r = generateSQL(b.right) - if (l.isDefined && r.isDefined) { - Some(s"(${l.get}) ${b.sqlOperator} (${r.get})") - } else { - None - } - case Not(EqualTo(left, right)) => - val l = generateSQL(left) - val r = generateSQL(right) - if (l.isDefined && r.isDefined) { - Some(s"${l.get} != ${r.get}") - } else { - None - } - case Not(child) => generateSQL(child).map(v => s"NOT ($v)") - case CaseWhen(branches, elseValue) => - val conditionsSQL = branches.map(_._1).flatMap(generateSQL) - val valuesSQL = branches.map(_._2).flatMap(generateSQL) - if (conditionsSQL.length == branches.length && valuesSQL.length == branches.length) { - val branchSQL = - conditionsSQL.zip(valuesSQL).map { case (c, v) => s" WHEN $c THEN $v" }.mkString - if (elseValue.isDefined) { - elseValue.flatMap(generateSQL).map(v => s"CASE$branchSQL ELSE $v END") - } else { - Some(s"CASE$branchSQL END") - } - } else { - None - } - // TODO supports other expressions - case _ => None - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala new file mode 100644 index 0000000000000..1e361695056a7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal, Multiply, Not, Or, Remainder, Subtract, UnaryMinus} +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} + +/** + * The builder to generate V2 expressions from catalyst expressions. + */ +class V2ExpressionBuilder(e: Expression) { + + def build(): Option[V2Expression] = generateExpression(e) + + private def canTranslate(b: BinaryOperator) = b match { + case _: And | _: Or => true + case _: BinaryComparison => true + case _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true + case add: Add => add.failOnError + case sub: Subtract => sub.failOnError + case mul: Multiply => mul.failOnError + case div: Divide => div.failOnError + case r: Remainder => r.failOnError + case _ => false + } + + private def generateExpression(expr: Expression): Option[V2Expression] = expr match { + case Literal(value, dataType) => Some(LiteralValue(value, dataType)) + case attr: Attribute => Some(FieldReference.column(attr.name)) + case IsNull(col) => generateExpression(col) + .map(c => new GeneralScalarExpression("IS_NULL", Array[V2Expression](c))) + case IsNotNull(col) => generateExpression(col) + .map(c => new GeneralScalarExpression("IS_NOT_NULL", Array[V2Expression](c))) + case b: BinaryOperator if canTranslate(b) => + val left = generateExpression(b.left) + val right = generateExpression(b.right) + if (left.isDefined && right.isDefined) { + Some(new GeneralScalarExpression(b.sqlOperator, Array[V2Expression](left.get, right.get))) + } else { + None + } + case Not(eq: EqualTo) => + val left = generateExpression(eq.left) + val right = generateExpression(eq.right) + if (left.isDefined && right.isDefined) { + Some(new GeneralScalarExpression("!=", Array[V2Expression](left.get, right.get))) + } else { + None + } + case Not(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("NOT", Array[V2Expression](v))) + case UnaryMinus(child, true) => generateExpression(child) + .map(v => new GeneralScalarExpression("-", Array[V2Expression](v))) + case BitwiseNot(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("~", Array[V2Expression](v))) + case CaseWhen(branches, elseValue) => + val conditions = branches.map(_._1).flatMap(generateExpression) + val values = branches.map(_._2).flatMap(generateExpression) + if (conditions.length == branches.length && values.length == branches.length) { + val branchExpressions = conditions.zip(values).flatMap { case (c, v) => + Seq[V2Expression](c, v) + } + if (elseValue.isDefined) { + elseValue.flatMap(generateExpression).map { v => + val children = (branchExpressions :+ v).toArray[V2Expression] + // The children looks 
like [condition1, value1, ..., conditionN, valueN, elseValue] + new GeneralScalarExpression("CASE_WHEN", children) + } + } else { + // The children looks like [condition1, value1, ..., conditionN, valueN] + Some(new GeneralScalarExpression("CASE_WHEN", branchExpressions.toArray[V2Expression])) + } + } else { + None + } + // TODO supports other expressions + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 29c73ba0cf59c..5d0aecb94264d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -38,10 +38,10 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 -import org.apache.spark.sql.catalyst.util.ExpressionSQLBuilder +import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{Expression => ExpressionV2, FieldReference, GeneralSQLExpression, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} @@ -770,8 +770,8 @@ object DataSourceStrategy Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)) } - protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = { - def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match { + protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = { + def translateOortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match { case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => val directionV2 = directionV1 match { case Ascending => SortDirection.ASCENDING @@ -858,8 +858,8 @@ object PushableColumnWithoutNestedColumn extends PushableColumnBase { * Get the expression of DS V2 to represent catalyst expression that can be pushed down. 
*/ object PushableExpression { - def unapply(e: Expression): Option[ExpressionV2] = e match { + def unapply(e: Expression): Option[V2Expression] = e match { case PushableColumnWithoutNestedColumn(name) => Some(FieldReference(name)) - case _ => new ExpressionSQLBuilder(e).build().map(new GeneralSQLExpression(_)) + case _ => new V2ExpressionBuilder(e).build() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 2d10bbf5de537..a7e0ec8b72a7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -32,8 +32,9 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, Timesta import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.index.TableIndex -import org.apache.spark.sql.connector.expressions.{FieldReference, GeneralSQLExpression, NamedReference} +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -194,6 +195,31 @@ abstract class JdbcDialect extends Serializable with Logging{ case _ => value } + class JDBCSQLBuilder extends V2ExpressionSQLBuilder { + override def visitFieldReference(fieldRef: FieldReference): String = { + if (fieldRef.fieldNames().length != 1) { + throw new IllegalArgumentException( + "FieldReference with field name has multiple or zero parts unsupported: " + fieldRef); + } + quoteIdentifier(fieldRef.fieldNames.head) + } + } + + /** + * Converts V2 expression to String representing a SQL expression. + * @param expr The V2 expression to be converted. + * @return Converted value. + */ + @Since("3.3.0") + def compileExpression(expr: Expression): Option[String] = { + val jdbcSQLBuilder = new JDBCSQLBuilder() + try { + Some(jdbcSQLBuilder.build(expr)) + } catch { + case _: IllegalArgumentException => None + } + } + /** * Converts aggregate function to String representing a SQL expression. * @param aggFunction The aggregate function to be converted. 
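As an aside, the new `compileExpression` hook lets a concrete dialect specialize the generated SQL by subclassing `JDBCSQLBuilder`. A rough sketch with a hypothetical `MyDialect` (an assumption for illustration, not part of this patch):

```scala
import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.jdbc.JdbcDialect

object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // Override a single visit method to emit vendor-specific syntax; everything else
  // falls back to the default SQL generated by V2ExpressionSQLBuilder.
  class MyDialectSQLBuilder extends JDBCSQLBuilder {
    override def visitBinaryComparison(name: String, l: String, r: String): String =
      if (name == "<=>") s"($l IS NOT DISTINCT FROM $r)" // hypothetical vendor spelling
      else super.visitBinaryComparison(name, l, r)
  }

  override def compileExpression(expr: Expression): Option[String] = {
    try {
      Some(new MyDialectSQLBuilder().build(expr))
    } catch {
      case _: IllegalArgumentException => None // unsupported expression, skip push-down
    }
  }
}
```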
@@ -203,55 +229,20 @@ abstract class JdbcDialect extends Serializable with Logging{ def compileAggregate(aggFunction: AggregateFunc): Option[String] = { aggFunction match { case min: Min => - val sql = min.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } - Some(s"MIN($sql)") + compileExpression(min.column).map(v => s"MIN($v)") case max: Max => - val sql = max.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } - Some(s"MAX($sql)") + compileExpression(max.column).map(v => s"MAX($v)") case count: Count => - val sql = count.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } val distinct = if (count.isDistinct) "DISTINCT " else "" - Some(s"COUNT($distinct$sql)") + compileExpression(count.column).map(v => s"COUNT($distinct$v)") case sum: Sum => - val sql = sum.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } val distinct = if (sum.isDistinct) "DISTINCT " else "" - Some(s"SUM($distinct$sql)") + compileExpression(sum.column).map(v => s"SUM($distinct$v)") case _: CountStar => Some("COUNT(*)") case avg: Avg => - val sql = avg.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } val distinct = if (avg.isDistinct) "DISTINCT " else "" - Some(s"AVG($distinct$sql)") + compileExpression(avg.column).map(v => s"AVG($distinct$v)") case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 91ac7db335cc3..e9c8131fe9bec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} @@ -97,7 +95,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): InMemoryTable = { + properties: java.util.Map[String, String]): InMemoryTable = { new InMemoryTable(name, schema, partitions, properties) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index d5417be0f229f..e4ba33c619a7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connector -import java.util import java.util.Collections import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, 
JavaStrLen} @@ -35,7 +34,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { - private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String] private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = { catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 2db9f0583a2ab..be4dc1eedb59d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector import java.io.File -import java.util import java.util.OptionalLong import scala.collection.JavaConverters._ @@ -542,7 +541,7 @@ abstract class SimpleBatchTable extends Table with SupportsRead { override def name(): String = this.getClass.toString - override def capabilities(): util.Set[TableCapability] = Set(BATCH_READ).asJava + override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ) } abstract class SimpleScanBuilder extends ScanBuilder @@ -565,7 +564,7 @@ trait TestingV2Source extends TableProvider { override def getTable( schema: StructType, partitioning: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { getTable(new CaseInsensitiveStringMap(properties)) } @@ -782,7 +781,7 @@ class SchemaRequiredDataSource extends TableProvider { override def getTable( schema: StructType, partitioning: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala index db71eeb75eae0..e3d61a846fdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.connector -import java.util - -import scala.collection.JavaConverters._ - import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} @@ -63,7 +59,7 @@ class TestLocalScanCatalog extends BasicInMemoryTableCatalog { ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val table = new TestLocalScanTable(ident.toString) tables.put(ident, table) table @@ -78,7 +74,8 @@ object TestLocalScanTable { class TestLocalScanTable(override val name: String) extends Table with SupportsRead { override def schema(): StructType = TestLocalScanTable.schema - override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) 
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new TestLocalScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index bb2acecc782b2..64c893ed74fdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util import scala.collection.JavaConverters._ @@ -138,8 +137,8 @@ class SimpleWritableDataSource extends TestingV2Source { new MyWriteBuilder(path, info) } - override def capabilities(): util.Set[TableCapability] = - Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE) } override def getTable(options: CaseInsensitiveStringMap): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index ce94d3b5c2fc0..5f2e0b28aeccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.connector -import java.util - -import scala.collection.JavaConverters._ - import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} @@ -217,7 +213,11 @@ private case object TestRelation extends LeafNode with NamedRelation { private case class CapabilityTable(_capabilities: TableCapability*) extends Table { override def name(): String = "capability_test_table" override def schema(): StructType = TableCapabilityCheckSuite.schema - override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava + override def capabilities(): java.util.Set[TableCapability] = { + val set = java.util.EnumSet.noneOf(classOf[TableCapability]) + _capabilities.foreach(set.add) + set + } } private class TestStreamSourceProvider extends StreamSourceProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala index bf2749d1afc53..0a0aaa8021996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connector -import java.util import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicBoolean @@ -35,7 +34,7 @@ import org.apache.spark.sql.types.StructType */ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension { - protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() + protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() private val tableCreated: AtomicBoolean = new AtomicBoolean(false) @@ -48,7 +47,7 @@ private[connector] trait 
TestV2SessionCatalogBase[T <: Table] extends Delegating name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): T + properties: java.util.Map[String, String]): T override def loadTable(ident: Identifier): Table = { if (tables.containsKey(ident)) { @@ -69,12 +68,12 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val key = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY val propsWithLocation = if (properties.containsKey(key)) { // Always set a location so that CREATE EXTERNAL TABLE won't fail with LOCATION not specified. if (!properties.containsKey(TableCatalog.PROP_LOCATION)) { - val newProps = new util.HashMap[String, String]() + val newProps = new java.util.HashMap[String, String]() newProps.putAll(properties) newProps.put(TableCatalog.PROP_LOCATION, "file:/abc") newProps diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index 847953e09cef7..c5be222645b19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.connector -import java.util - -import scala.collection.JavaConverters._ - import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext} import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} @@ -106,7 +102,7 @@ class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog { ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { // To simplify the test implementation, only support fixed schema. 
if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) { throw new UnsupportedOperationException @@ -131,8 +127,8 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp override def schema(): StructType = V1ReadFallbackCatalog.schema - override def capabilities(): util.Set[TableCapability] = { - Set(TableCapability.BATCH_READ).asJava + override def capabilities(): java.util.Set[TableCapability] = { + java.util.EnumSet.of(TableCapability.BATCH_READ) } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 7effc747ab323..992c46cc6cdb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import scala.collection.JavaConverters._ import scala.collection.mutable @@ -223,7 +221,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { + properties: java.util.Map[String, String]): InMemoryTableWithV1Fallback = { val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) InMemoryV1Provider.tables.put(name, t) tables.put(Identifier.of(Array("default"), name), t) @@ -321,7 +319,7 @@ class InMemoryTableWithV1Fallback( override val name: String, override val schema: StructType, override val partitioning: Array[Transform], - override val properties: util.Map[String, String]) + override val properties: java.util.Map[String, String]) extends Table with SupportsWrite with SupportsRead { @@ -331,11 +329,11 @@ class InMemoryTableWithV1Fallback( } } - override def capabilities: util.Set[TableCapability] = Set( + override def capabilities: java.util.Set[TableCapability] = java.util.EnumSet.of( TableCapability.BATCH_READ, TableCapability.V1_BATCH_WRITE, TableCapability.OVERWRITE_BY_FILTER, - TableCapability.TRUNCATE).asJava + TableCapability.TRUNCATE) @volatile private var dataMap: mutable.Map[Seq[Any], Seq[Row]] = mutable.Map.empty private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index c6d0a87787639..5a7e1c00494a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager} import java.util.Properties -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} @@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import 
org.apache.spark.sql.functions.{avg, count, lit, sum, udf} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -852,6 +853,34 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Row(2, 2, 2, 2, 2, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 3, 0d))) } + test("scan with aggregate push-down: aggregate function with binary arithmetic") { + Seq(false, true).foreach { ansiMode => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { + val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") + checkAggregateRemoved(df, ansiMode) + val expected_plan_fragment = if (ansiMode) { + "PushedAggregates: [SUM((2147483647) + (DEPT))], " + + "PushedFilters: [], PushedGroupByColumns: []" + } else { + "PushedFilters: []" + } + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + if (ansiMode) { + val e = intercept[SparkException] { + checkAnswer(df, Seq(Row(-10737418233L))) + } + assert(e.getMessage.contains( + "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\"")) + } else { + checkAnswer(df, Seq(Row(-10737418233L))) + } + } + } + } + test("scan with aggregate push-down: aggregate function with UDF") { val df = spark.table("h2.test.employee") val decrease = udf { (x: Double, y: Double) => x - y } From ed0d635b0449919fc5f25853d3dcc1dca10c3d76 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 8 Mar 2022 16:38:23 +0800 Subject: [PATCH 34/53] [SPARK-38361][SQL] Add factory method `getConnection` into `JDBCDialect` ### What changes were proposed in this pull request? At present, the parameter of the factory method for obtaining JDBC connection is empty because the JDBC URL of some databases is fixed and unique. However, for databases such as ClickHouse, connection is related to the shard node. So I think the parameter form of `getConnection: Partition = > Connection` is more general. This PR adds factory method `getConnection` into `JDBCDialect` according to https://github.com/apache/spark/pull/35696#issuecomment-1058060107. ### Why are the changes needed? Make factory method `getConnection` more general. ### Does this PR introduce _any_ user-facing change? 'No'. Just inner change. ### How was this patch tested? Exists test. Closes #35727 from beliefer/SPARK-38361_new. 
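To illustrate the calling convention introduced here (a hedged sketch distilled from the diff below, not additional API): the factory returned by `createConnectionFactory` is applied to the RDD partition index on executors, and to -1 wherever the connection is opened on the driver side.

```scala
import java.sql.Connection

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.JdbcDialects

// Driver side (schema resolution, table existence checks, truncate, ...):
// no RDD partition is involved, so the factory is applied to -1.
def openDriverSideConnection(options: JDBCOptions): Connection = {
  val dialect = JdbcDialects.get(options.url)
  dialect.createConnectionFactory(options)(-1)
}

// Executor side (JDBCRDD.compute): the factory is shipped with the RDD and
// applied to the partition index, so a shard-aware dialect can route each
// partition to a different node:
//   val getConnection: Int => Connection = dialect.createConnectionFactory(options)
//   val conn = getConnection(part.idx)
```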
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../execution/datasources/jdbc/JDBCRDD.scala | 8 ++--- .../jdbc/JdbcRelationProvider.scala | 5 +-- .../datasources/jdbc/JdbcUtils.scala | 32 ++++--------------- .../jdbc/connection/ConnectionProvider.scala | 2 ++ .../v2/jdbc/JDBCWriteBuilder.scala | 4 ++- .../apache/spark/sql/jdbc/JdbcDialects.scala | 28 ++++++++++++++-- 6 files changed, 44 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index baee53847a5a4..b5224eaf7262b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -60,7 +60,7 @@ object JDBCRDD extends Logging { def getQueryOutputSchema( query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = { - val conn: Connection = JdbcUtils.createConnectionFactory(options)() + val conn: Connection = dialect.createConnectionFactory(options)(-1) try { val statement = conn.prepareStatement(query) try { @@ -182,7 +182,7 @@ object JDBCRDD extends Logging { } new JDBCRDD( sc, - JdbcUtils.createConnectionFactory(options), + dialect.createConnectionFactory(options), outputSchema.getOrElse(pruneSchema(schema, requiredColumns)), quotedColumns, filters, @@ -204,7 +204,7 @@ object JDBCRDD extends Logging { */ private[jdbc] class JDBCRDD( sc: SparkContext, - getConnection: () => Connection, + getConnection: Int => Connection, schema: StructType, columns: Array[String], filters: Array[Filter], @@ -318,7 +318,7 @@ private[jdbc] class JDBCRDD( val inputMetrics = context.taskMetrics().inputMetrics val part = thePart.asInstanceOf[JDBCPartition] - conn = getConnection() + conn = getConnection(part.idx) val dialect = JdbcDialects.get(url) import scala.collection.JavaConverters._ dialect.beforeFetch(conn, options.asProperties.asScala.toMap) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index d953ba45cc2fb..2760c7ac3019c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._ +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} class JdbcRelationProvider extends CreatableRelationProvider @@ -45,8 +46,8 @@ class JdbcRelationProvider extends CreatableRelationProvider df: DataFrame): BaseRelation = { val options = new JdbcOptionsInWrite(parameters) val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis - - val conn = JdbcUtils.createConnectionFactory(options)() + val dialect = JdbcDialects.get(options.url) + val conn = dialect.createConnectionFactory(options)(-1) try { val tableExists = JdbcUtils.tableExists(conn, options) if (tableExists) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index da627a16099c1..100eb23a4bbf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Driver, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} import java.time.{Instant, LocalDate} import java.util import java.util.Locale @@ -43,7 +43,6 @@ import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex} import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ @@ -54,24 +53,7 @@ import org.apache.spark.util.NextIterator /** * Util functions for JDBC tables. */ -object JdbcUtils extends Logging { - /** - * Returns a factory for creating connections to the given JDBC URL. - * - * @param options - JDBC options that contains url, table and other information. - * @throws IllegalArgumentException if the driver could not open a JDBC connection. - */ - def createConnectionFactory(options: JDBCOptions): () => Connection = { - val driverClass: String = options.driverClass - () => { - DriverRegistry.register(driverClass) - val driver: Driver = DriverRegistry.get(driverClass) - val connection = ConnectionProvider.create(driver, options.parameters) - require(connection != null, - s"The driver could not open a JDBC connection. Check the URL: ${options.url}") - connection - } - } +object JdbcUtils extends Logging with SQLConfHelper { /** * Returns true if the table already exists in the JDBC database. @@ -656,7 +638,6 @@ object JdbcUtils extends Logging { * updated even with error if it doesn't support transaction, as there're dirty outputs. 
*/ def savePartition( - getConnection: () => Connection, table: String, iterator: Iterator[Row], rddSchema: StructType, @@ -667,7 +648,7 @@ object JdbcUtils extends Logging { options: JDBCOptions): Unit = { val outMetrics = TaskContext.get().taskMetrics().outputMetrics - val conn = getConnection() + val conn = dialect.createConnectionFactory(options)(-1) var committed = false var finalIsolationLevel = Connection.TRANSACTION_NONE @@ -879,7 +860,6 @@ object JdbcUtils extends Logging { val table = options.table val dialect = JdbcDialects.get(url) val rddSchema = df.schema - val getConnection: () => Connection = createConnectionFactory(options) val batchSize = options.batchSize val isolationLevel = options.isolationLevel @@ -891,8 +871,7 @@ object JdbcUtils extends Logging { case _ => df } repartitionedDF.rdd.foreachPartition { iterator => savePartition( - getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, - options) + table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, options) } } @@ -1182,7 +1161,8 @@ object JdbcUtils extends Logging { } def withConnection[T](options: JDBCOptions)(f: Connection => T): T = { - val conn = createConnectionFactory(options)() + val dialect = JdbcDialects.get(options.url) + val conn = dialect.createConnectionFactory(options)(-1) try { f(conn) } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala index fbc69704f1479..f84fdb90dd55e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala @@ -73,3 +73,5 @@ private[jdbc] object ConnectionProvider extends Logging { } } } + +private[sql] object ConnectionProvider extends ConnectionProviderBase diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala index 0e6c72c2cc331..7449f66ee020f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala @@ -20,6 +20,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.execution.datasources.jdbc.{JdbcOptionsInWrite, JdbcUtils} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.StructType @@ -37,7 +38,8 @@ case class JDBCWriteBuilder(schema: StructType, options: JdbcOptionsInWrite) ext override def toInsertableRelation: InsertableRelation = (data: DataFrame, _: Boolean) => { // TODO (SPARK-32595): do truncate and append atomically. 
if (isTruncate) { - val conn = JdbcUtils.createConnectionFactory(options)() + val dialect = JdbcDialects.get(options.url) + val conn = dialect.createConnectionFactory(options)(-1) JdbcUtils.truncateTable(conn, options) } JdbcUtils.saveTable(data, Some(schema), SQLConf.get.caseSensitiveAnalysis, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index a7e0ec8b72a7c..c9dcbb2706cd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, Date, Statement, Timestamp} +import java.sql.{Connection, Date, Driver, Statement, Timestamp} import java.time.{Instant, LocalDate} import java.util @@ -36,7 +36,8 @@ import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, N import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -100,6 +101,29 @@ abstract class JdbcDialect extends Serializable with Logging{ */ def getJDBCType(dt: DataType): Option[JdbcType] = None + /** + * Returns a factory for creating connections to the given JDBC URL. + * In general, creating a connection has nothing to do with JDBC partition id. + * But sometimes it is needed, such as a database with multiple shard nodes. + * @param options - JDBC options that contains url, table and other information. + * @return The factory method for creating JDBC connections with the RDD partition ID. -1 means + the connection is being created at the driver side. + * @throws IllegalArgumentException if the driver could not open a JDBC connection. + */ + @Since("3.3.0") + def createConnectionFactory(options: JDBCOptions): Int => Connection = { + val driverClass: String = options.driverClass + (partitionId: Int) => { + DriverRegistry.register(driverClass) + val driver: Driver = DriverRegistry.get(driverClass) + val connection = + ConnectionProvider.create(driver, options.parameters, options.connectionProviderName) + require(connection != null, + s"The driver could not open a JDBC connection. Check the URL: ${options.url}") + connection + } + } + /** * Quotes the identifier. This is used to put quotes around the identifier in case the column * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space). 
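Because the factory is now defined on `JdbcDialect` and receives the partition id, a dialect for a sharded database can override it to route each Spark partition to its own node. The following is only a hedged sketch: the `ShardedDialect` class, its URL prefix, and the shard-routing rule are hypothetical and not part of this patch; only the `createConnectionFactory` signature comes from the change above.

```scala
import java.sql.{Connection, DriverManager}

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.JdbcDialect

// Hypothetical dialect for a database whose shards sit behind different JDBC
// URLs; everything except the createConnectionFactory signature is made up.
case class ShardedDialect(shardUrls: Seq[String]) extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mysharded:")

  override def createConnectionFactory(options: JDBCOptions): Int => Connection = {
    (partitionId: Int) => {
      // -1 means the connection is requested on the driver side; any shard works.
      val url =
        if (partitionId < 0) shardUrls.head
        else shardUrls(partitionId % shardUrls.size)
      DriverManager.getConnection(url, options.asConnectionProperties)
    }
  }
}

// A dialect like this would be registered once with
// JdbcDialects.registerDialect(...) before reading the table.
```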
From ea792859c4fb0f7f2c80293e68337ed25eb135ec Mon Sep 17 00:00:00 2001 From: chenzhx Date: Wed, 9 Mar 2022 17:19:43 +0800 Subject: [PATCH 35/53] code format --- .../apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala | 2 +- .../apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 2 +- .../datasources/jdbc/connection/ConnectionProvider.scala | 2 +- .../src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- .../org/apache/spark/sql/connector/DataSourceV2Suite.scala | 2 -- 5 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 1e361695056a7..2ffae68284cc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -41,7 +41,7 @@ class V2ExpressionBuilder(e: Expression) { private def generateExpression(expr: Expression): Option[V2Expression] = expr match { case Literal(value, dataType) => Some(LiteralValue(value, dataType)) - case attr: Attribute => Some(FieldReference.column(attr.name)) + case attr: Attribute => Some(FieldReference(attr.name)) case IsNull(col) => generateExpression(col) .map(c => new GeneralScalarExpression("IS_NULL", Array[V2Expression](c))) case IsNotNull(col) => generateExpression(col) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 100eb23a4bbf5..2d0cbcff8ecc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -53,7 +53,7 @@ import org.apache.spark.util.NextIterator /** * Util functions for JDBC tables. */ -object JdbcUtils extends Logging with SQLConfHelper { +object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala index f84fdb90dd55e..ed8398f265848 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcConnectionProvider import org.apache.spark.util.Utils -private[jdbc] object ConnectionProvider extends Logging { +protected abstract class ConnectionProviderBase extends Logging { private val providers = loadProviders() def loadProviders(): Seq[JdbcConnectionProvider] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c9dcbb2706cd4..e886d8b8deae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -117,7 +117,7 @@ abstract class JdbcDialect extends Serializable with Logging{ DriverRegistry.register(driverClass) val driver: Driver = DriverRegistry.get(driverClass) val connection = - ConnectionProvider.create(driver, options.parameters, options.connectionProviderName) + ConnectionProvider.create(driver, options.parameters) require(connection != null, s"The driver could not open a JDBC connection. Check the URL: ${options.url}") connection diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index be4dc1eedb59d..1f19836834171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.connector import java.io.File import java.util.OptionalLong -import scala.collection.JavaConverters._ - import test.org.apache.spark.sql.connector._ import org.apache.spark.SparkException From e16fe8aa645d06972aca1ca1cbf3769856ad8eee Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 17 Mar 2022 16:53:40 +0800 Subject: [PATCH 36/53] [SPARK-38560][SQL] If `Sum`, `Count`, `Any` accompany with distinct, cannot do partial agg push down ### What changes were proposed in this pull request? Spark could partially push down sum(distinct col) and count(distinct col) when the data source has multiple partitions, and would then sum the partial values again, so the result may be incorrect. ### Why are the changes needed? Fix the bug where pushing down sum(distinct col) and count(distinct col) to the data source returns an incorrect result. ### Does this PR introduce _any_ user-facing change? 'Yes'. Users will see the correct behavior. ### How was this patch tested? New tests. Closes #35873 from beliefer/SPARK-38560.
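A concrete illustration of why the partial push-down is unsafe, using a hypothetical two-partition layout (the data values below are made up for the example): re-aggregating per-partition distinct results double-counts any value that appears in more than one partition.

```scala
// Hypothetical DEPT values held by two JDBC partitions (illustrative data only):
val partition0 = Seq(1, 1, 2)
val partition1 = Seq(2, 6)

// Partial push-down would let each partition compute COUNT(DISTINCT DEPT)
// locally, and Spark would then sum the partial results:
val partialThenSum = partition0.distinct.size + partition1.distinct.size // 2 + 2 = 4

// The correct answer de-duplicates across partitions before counting:
val correct = (partition0 ++ partition1).distinct.size // 3

// partialThenSum != correct because DEPT = 2 appears in both partitions,
// which is why distinct Sum/Count/Avg are now excluded from partial push-down.
```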
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../v2/V2ScanRelationPushDown.scala | 184 ++++++++++-------- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 14 +- 2 files changed, 111 insertions(+), 87 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 05857c545cdf6..b97823fcd09e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.SortOrder -import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, GeneralAggregateFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources @@ -147,101 +147,106 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } - val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) - if (pushedAggregates.isEmpty) { + if (finalTranslatedAggregates.isEmpty) { aggNode // return original plan node - } else if (!supportPartialAggPushDown(pushedAggregates.get) && - !r.supportCompletePushDown(pushedAggregates.get)) { + } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) && + !supportPartialAggPushDown(finalTranslatedAggregates.get)) { aggNode // return original plan node } else { - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. - val scan = sHolder.builder.build() - - // scalastyle:off - // use the group by columns and aggregate columns as the output columns - // e.g. 
TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] - // scalastyle:on - val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + finalAggregates.length) - val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { - case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) - case (_, b) => b - } - val aggOutput = newOutput.drop(groupAttrs.length) - val output = groupAttrs ++ aggOutput - - logInfo( - s""" - |Pushing operators to ${sHolder.relation.name} - |Pushed Aggregate Functions: - | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} - |Pushed Group by: - | ${pushedAggregates.get.groupByColumns.mkString(", ")} - |Output: ${output.mkString(", ")} - """.stripMargin) - - val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) - val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - if (r.supportCompletePushDown(pushedAggregates.get)) { - val projectExpressions = resultExpressions.map { expr => - // TODO At present, only push down group by attribute is supported. - // In future, more attribute conversion is extended here. e.g. GetStructField - expr.transform { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val child = - addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) - Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) - } - }.asInstanceOf[Seq[NamedExpression]] - Project(projectExpressions, scanRelation) + val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) + if (pushedAggregates.isEmpty) { + aggNode // return original plan node } else { - val plan = Aggregate( - output.take(groupingExpressions.length), finalResultExpressions, scanRelation) + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. + val scan = sHolder.builder.build() // scalastyle:off - // Change the optimized logical plan to reflect the pushed down aggregate + // use the group by columns and aggregate columns as the output columns // e.g. TABLE t (c1 INT, c2 INT, c3 INT) // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9, c2#10] ... - // - // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] - // we have the following - // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // - // We want to change it to + // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: // == Optimized Logical Plan == // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... 
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] // scalastyle:on - plan.transformExpressions { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val aggAttribute = aggOutput(ordinal) - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case max: aggregate.Max => - max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType)) - case min: aggregate.Min => - min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType)) - case sum: aggregate.Sum => - sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType)) - case _: aggregate.Count => - aggregate.Sum(addCastIfNeeded(aggAttribute, LongType)) - case other => other - } - agg.copy(aggregateFunction = aggFunction) + val newOutput = scan.readSchema().toAttributes + assert(newOutput.length == groupingExpressions.length + finalAggregates.length) + val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { + case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) + case (_, b) => b + } + val aggOutput = newOutput.drop(groupAttrs.length) + val output = groupAttrs ++ aggOutput + + logInfo( + s""" + |Pushing operators to ${sHolder.relation.name} + |Pushed Aggregate Functions: + | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} + |Pushed Group by: + | ${pushedAggregates.get.groupByColumns.mkString(", ")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) + val scanRelation = + DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) + if (r.supportCompletePushDown(pushedAggregates.get)) { + val projectExpressions = resultExpressions.map { expr => + // TODO At present, only push down group by attribute is supported. + // In future, more attribute conversion is extended here. e.g. GetStructField + expr.transform { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val child = + addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) + Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) + } + }.asInstanceOf[Seq[NamedExpression]] + Project(projectExpressions, scanRelation) + } else { + val plan = Aggregate(output.take(groupingExpressions.length), + finalResultExpressions, scanRelation) + + // scalastyle:off + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... + // + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] + // we have the following + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... 
+ // scalastyle:on + plan.transformExpressions { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val aggAttribute = aggOutput(ordinal) + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case max: aggregate.Max => + max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType)) + case min: aggregate.Min => + min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType)) + case sum: aggregate.Sum => + sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType)) + case _: aggregate.Count => + aggregate.Sum(addCastIfNeeded(aggAttribute, LongType)) + case other => other + } + agg.copy(aggregateFunction = aggFunction) + } } } } @@ -270,7 +275,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { private def supportPartialAggPushDown(agg: Aggregation): Boolean = { // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. - agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc]) + // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down. + agg.aggregateExpressions().exists { + case sum: Sum => !sum.isDistinct + case count: Count => !count.isDistinct + case avg: Avg => !avg.isDistinct + case _: GeneralAggregateFunc => false + case _ => true + } } private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 5a7e1c00494a5..6d0acdc700723 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{avg, count, lit, sum, udf} +import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, sum, udf} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -509,6 +509,18 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row(3))) } + test("scan with aggregate push-down: cannot partial push down COUNT(DISTINCT col)") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .agg(count_distinct($"DEPT")) + checkAggregateRemoved(df, false) + checkAnswer(df, Seq(Row(3))) + } + test("scan with aggregate push-down: SUM without filer and group by") { val df = sql("SELECT SUM(SALARY) FROM h2.test.employee") checkAggregateRemoved(df) From 0fce03d81b06a730fb2d2f15a81d5dbd34e1ad2b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 17 Sep 2021 21:34:21 +0800 Subject: [PATCH 37/53] [SPARK-36718][SQL] Only collapse projects if we don't duplicate expensive expressions ### What changes were proposed in this pull request? The `CollapseProject` rule can combine adjacent projects and merge the project lists. 
The key idea behind this rule is that the evaluation of project is relatively expensive, and that expression evaluation is cheap and that the expression duplication caused by this rule is not a problem. This last assumption is, unfortunately, not always true: - A user can invoke some expensive UDF, this now gets invoked more often than originally intended. - A projection is very cheap in whole stage code generation. The duplication caused by `CollapseProject` does more harm than good here. This PR addresses this problem, by only collapsing projects when it does not duplicate expensive expressions. In practice this means an input reference may only be consumed once, or when its evaluation does not incur significant overhead (currently attributes, nested column access, aliases & literals fall in this category). ### Why are the changes needed? We have seen multiple complains about `CollapseProject` in the past, due to it may duplicate expensive expressions. The most recent one is https://github.com/apache/spark/pull/33903 . ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? a new UT and existing test Closes #33958 from cloud-fan/collapse. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../catalyst/expressions/AliasHelper.scala | 8 +- .../sql/catalyst/optimizer/Optimizer.scala | 123 ++++++++---- .../sql/catalyst/planning/patterns.scala | 180 +++++++----------- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../optimizer/CollapseProjectSuite.scala | 10 + .../planning/ScanOperationSuite.scala | 28 ++- 6 files changed, 197 insertions(+), 159 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala index 0007d3868eda2..dea7ea0f144bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.MultiAlias import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} +import org.apache.spark.sql.types.Metadata /** * Helper methods for collecting and replacing aliases. 
@@ -86,10 +87,15 @@ trait AliasHelper { protected def trimNonTopLevelAliases[T <: Expression](e: T): T = { val res = e match { case a: Alias => + val metadata = if (a.metadata == Metadata.empty) { + None + } else { + Some(a.metadata) + } a.copy(child = trimAliases(a.child))( exprId = a.exprId, qualifier = a.qualifier, - explicitMetadata = Some(a.metadata), + explicitMetadata = metadata, nonInheritableMetadataKeys = a.nonInheritableMetadataKeys) case a: MultiAlias => a.copy(child = trimAliases(a.child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index eb040e23290c9..1a57ee83fa3ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -912,56 +912,69 @@ object ColumnPruning extends Rule[LogicalPlan] { */ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( - _.containsPattern(PROJECT), ruleId) { - case p1 @ Project(_, p2: Project) => - if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) { - p1 - } else { + def apply(plan: LogicalPlan): LogicalPlan = { + val alwaysInline = conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) + plan.transformUpWithPruning(_.containsPattern(PROJECT), ruleId) { + case p1 @ Project(_, p2: Project) + if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) => p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) - } - case p @ Project(_, agg: Aggregate) => - if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions) || - !canCollapseAggregate(p, agg)) { - p - } else { + case p @ Project(_, agg: Aggregate) + if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) => agg.copy(aggregateExpressions = buildCleanedProjectList( p.projectList, agg.aggregateExpressions)) - } - case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _)))) + case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _)))) if isRenaming(l1, l2) => - val newProjectList = buildCleanedProjectList(l1, l2) - g.copy(child = limit.copy(child = p2.copy(projectList = newProjectList))) - case Project(l1, limit @ LocalLimit(_, p2 @ Project(l2, _))) if isRenaming(l1, l2) => - val newProjectList = buildCleanedProjectList(l1, l2) - limit.copy(child = p2.copy(projectList = newProjectList)) - case Project(l1, r @ Repartition(_, _, p @ Project(l2, _))) if isRenaming(l1, l2) => - r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, p.projectList))) - case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) => - s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList))) - } - - private def haveCommonNonDeterministicOutput( - upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = { - val aliases = getAliasMap(lower) + val newProjectList = buildCleanedProjectList(l1, l2) + g.copy(child = limit.copy(child = p2.copy(projectList = newProjectList))) + case Project(l1, limit @ LocalLimit(_, p2 @ Project(l2, _))) if isRenaming(l1, l2) => + val newProjectList = buildCleanedProjectList(l1, l2) + limit.copy(child = p2.copy(projectList = newProjectList)) + case Project(l1, r @ Repartition(_, _, p @ Project(l2, _))) if isRenaming(l1, l2) => + r.copy(child = 
p.copy(projectList = buildCleanedProjectList(l1, p.projectList))) + case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) => + s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList))) + } + } - // Collapse upper and lower Projects if and only if their overlapped expressions are all - // deterministic. - upper.exists(_.collect { - case a: Attribute if aliases.contains(a) => aliases(a).child - }.exists(!_.deterministic)) + /** + * Check if we can collapse expressions safely. + */ + def canCollapseExpressions( + consumers: Seq[Expression], + producers: Seq[NamedExpression], + alwaysInline: Boolean): Boolean = { + canCollapseExpressions(consumers, getAliasMap(producers), alwaysInline) } /** - * A project cannot be collapsed with an aggregate when there are correlated scalar - * subqueries in the project list, because currently we only allow correlated subqueries - * in aggregate if they are also part of the grouping expressions. Otherwise the plan - * after subquery rewrite will not be valid. + * Check if we can collapse expressions safely. */ - private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = { - p.projectList.forall(_.collect { - case s: ScalarSubquery if s.outerAttrs.nonEmpty => s - }.isEmpty) + def canCollapseExpressions( + consumers: Seq[Expression], + producerMap: Map[Attribute, Expression], + alwaysInline: Boolean = false): Boolean = { + // We can only collapse expressions if all input expressions meet the following criteria: + // - The input is deterministic. + // - The input is only consumed once OR the underlying input expression is cheap. + consumers.flatMap(collectReferences) + .groupBy(identity) + .mapValues(_.size) + .forall { + case (reference, count) => + val producer = producerMap.getOrElse(reference, reference) + producer.deterministic && (count == 1 || alwaysInline || { + val relatedConsumers = consumers.filter(_.references.contains(reference)) + val extractOnly = relatedConsumers.forall(isExtractOnly(_, reference)) + shouldInline(producer, extractOnly) + }) + } + } + + private def isExtractOnly(expr: Expression, ref: Attribute): Boolean = expr match { + case a: Alias => isExtractOnly(a.child, ref) + case e: ExtractValue => isExtractOnly(e.children.head, ref) + case a: Attribute => a.semanticEquals(ref) + case _ => false } private def buildCleanedProjectList( @@ -971,6 +984,34 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { upper.map(replaceAliasButKeepName(_, aliases)) } + /** + * Check if the given expression is cheap that we can inline it. + */ + private def shouldInline(e: Expression, extractOnlyConsumer: Boolean): Boolean = e match { + case _: Attribute | _: OuterReference => true + case _ if e.foldable => true + // PythonUDF is handled by the rule ExtractPythonUDFs + case _: PythonUDF => true + // Alias and ExtractValue are very cheap. + case _: Alias | _: ExtractValue => e.children.forall(shouldInline(_, extractOnlyConsumer)) + // These collection create functions are not cheap, but we have optimizer rules that can + // optimize them out if they are only consumed by ExtractValue, so we need to allow to inline + // them to avoid perf regression. 
As an example: + // Project(s.a, s.b, Project(create_struct(a, b, c) as s, child)) + // We should collapse these two projects and eventually get Project(a, b, child) + case _: CreateNamedStruct | _: CreateArray | _: CreateMap | _: UpdateFields => + extractOnlyConsumer + case _ => false + } + + /** + * Return all the references of the given expression without deduplication, which is different + * from `Expression.references`. + */ + private def collectReferences(e: Expression): Seq[Attribute] = e.collect { + case a: Attribute => a + } + private def isRenaming(list1: Seq[NamedExpression], list2: Seq[NamedExpression]): Boolean = { list1.length == list2.length && list1.zip(list2).forall { case (e1, e2) if e1.semanticEquals(e2) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index fc12f48ec2a11..f33d137ffd607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -26,46 +26,32 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -trait OperationHelper { - type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan) - - protected def collectAliases(fields: Seq[Expression]): AttributeMap[Expression] = - AttributeMap(fields.collect { - case a: Alias => (a.toAttribute, a.child) - }) - - protected def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = { - // use transformUp instead of transformDown to avoid dead loop - // in case of there's Alias whose exprId is the same as its child attribute. - expr.transformUp { - case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref) - .map(Alias(_, name)(a.exprId, a.qualifier)) - .getOrElse(a) - - case a: AttributeReference => - aliases.get(a) - .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a) - } - } -} +trait OperationHelper extends AliasHelper with PredicateHelper { + import org.apache.spark.sql.catalyst.optimizer.CollapseProject.canCollapseExpressions -/** - * A pattern that matches any number of project or filter operations on top of another relational - * operator. All filter operators are collected and their conditions are broken up and returned - * together with the top project operator. - * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if - * necessary. - */ -object PhysicalOperation extends OperationHelper with PredicateHelper { + type ReturnType = + (Seq[NamedExpression], Seq[Expression], LogicalPlan) + type IntermediateType = + (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Alias]) def unapply(plan: LogicalPlan): Option[ReturnType] = { - val (fields, filters, child, _) = collectProjectsAndFilters(plan) + val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) + val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline) Some((fields.getOrElse(child.output), filters, child)) } /** - * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary. + * This legacy mode is for PhysicalOperation which has been there for years and we want to be + * extremely safe to not change its behavior. There are two differences when legacy mode is off: + * 1. 
We postpone the deterministic check to the very end (calling `canCollapseExpressions`), + * so that it's more likely to collect more projects and filters. + * 2. We follow CollapseProject and only collect adjacent projects if they don't produce + * repeated expensive expressions. + */ + protected def legacyMode: Boolean + + /** + * Collects all adjacent projects and filters, in-lining/substituting aliases if necessary. * Here are two examples for alias in-lining/substitution. * Before: * {{{ @@ -78,25 +64,60 @@ object PhysicalOperation extends OperationHelper with PredicateHelper { * SELECT key AS c2 FROM t1 WHERE key > 10 * }}} */ - private def collectProjectsAndFilters(plan: LogicalPlan): - (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Expression]) = + private def collectProjectsAndFilters( + plan: LogicalPlan, + alwaysInline: Boolean): IntermediateType = { + def empty: IntermediateType = (None, Nil, plan, AttributeMap.empty) + plan match { - case Project(fields, child) if fields.forall(_.deterministic) => - val (_, filters, other, aliases) = collectProjectsAndFilters(child) - val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] - (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) + case Project(fields, child) if !legacyMode || fields.forall(_.deterministic) => + val (_, filters, other, aliases) = collectProjectsAndFilters(child, alwaysInline) + if (legacyMode || canCollapseExpressions(fields, aliases, alwaysInline)) { + val replaced = fields.map(replaceAliasButKeepName(_, aliases)) + (Some(replaced), filters, other, getAliasMap(replaced)) + } else { + empty + } - case Filter(condition, child) if condition.deterministic => - val (fields, filters, other, aliases) = collectProjectsAndFilters(child) - val substitutedCondition = substitute(aliases)(condition) - (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) + case Filter(condition, child) if !legacyMode || condition.deterministic => + val (fields, filters, other, aliases) = collectProjectsAndFilters(child, alwaysInline) + val canIncludeThisFilter = if (legacyMode) { + true + } else { + // When collecting projects and filters, we effectively push down filters through + // projects. We need to meet the following conditions to do so: + // 1) no Project collected so far or the collected Projects are all deterministic + // 2) the collected filters and this filter are all deterministic, or this is the + // first collected filter. + // 3) this filter does not repeat any expensive expressions from the collected + // projects. + fields.forall(_.forall(_.deterministic)) && { + filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic) + } && canCollapseExpressions(Seq(condition), aliases, alwaysInline) + } + if (canIncludeThisFilter) { + val replaced = replaceAlias(condition, aliases) + (fields, filters ++ splitConjunctivePredicates(replaced), other, aliases) + } else { + empty + } - case h: ResolvedHint => - collectProjectsAndFilters(h.child) + case h: ResolvedHint => collectProjectsAndFilters(h.child, alwaysInline) - case other => - (None, Nil, other, AttributeMap(Seq())) + case _ => empty } + } +} + +/** + * A pattern that matches any number of project or filter operations on top of another relational + * operator. All filter operators are collected and their conditions are broken up and returned + * together with the top project operator. 
+ * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if + * necessary. + */ +object PhysicalOperation extends OperationHelper with PredicateHelper { + override protected def legacyMode: Boolean = true } /** @@ -105,70 +126,7 @@ object PhysicalOperation extends OperationHelper with PredicateHelper { * requirement of CollapseProject and CombineFilters. */ object ScanOperation extends OperationHelper with PredicateHelper { - type ScanReturnType = Option[(Option[Seq[NamedExpression]], - Seq[Expression], LogicalPlan, AttributeMap[Expression])] - - def unapply(plan: LogicalPlan): Option[ReturnType] = { - collectProjectsAndFilters(plan) match { - case Some((fields, filters, child, _)) => - Some((fields.getOrElse(child.output), filters, child)) - case None => None - } - } - - private def hasCommonNonDeterministic( - expr: Seq[Expression], - aliases: AttributeMap[Expression]): Boolean = { - expr.exists(_.collect { - case a: AttributeReference if aliases.contains(a) => aliases(a) - }.exists(!_.deterministic)) - } - - private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = { - plan match { - case Project(fields, child) => - collectProjectsAndFilters(child) match { - case Some((_, filters, other, aliases)) => - // Follow CollapseProject and only keep going if the collected Projects - // do not have common non-deterministic expressions. - if (!hasCommonNonDeterministic(fields, aliases)) { - val substitutedFields = - fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] - Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields))) - } else { - None - } - case None => None - } - - case Filter(condition, child) => - collectProjectsAndFilters(child) match { - case Some((fields, filters, other, aliases)) => - // When collecting projects and filters, we effectively push down filters through - // projects. We need to meet the following conditions to do so: - // 1) no Project collected so far or the collected Projects are all deterministic - // 2) the collected filters and this filter are all deterministic, or this is the - // first collected filter. 
- val canCombineFilters = fields.forall(_.forall(_.deterministic)) && { - filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic) - } - val substitutedCondition = substitute(aliases)(condition) - if (canCombineFilters && !hasCommonNonDeterministic(Seq(condition), aliases)) { - Some((fields, filters ++ splitConjunctivePredicates(substitutedCondition), - other, aliases)) - } else { - None - } - case None => None - } - - case h: ResolvedHint => - collectProjectsAndFilters(h.child) - - case other => - Some((None, Nil, other, AttributeMap(Seq()))) - } - } + override protected def legacyMode: Boolean = false } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4518b9ddfc5ab..96ca754cad220 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1868,6 +1868,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val COLLAPSE_PROJECT_ALWAYS_INLINE = buildConf("spark.sql.optimizer.collapseProjectAlwaysInline") + .doc("Whether to always collapse two adjacent projections and inline expressions even if " + + "it causes extra duplication.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val FILE_SINK_LOG_DELETION = buildConf("spark.sql.streaming.fileSink.log.deletion") .internal() .doc("Whether to delete the expired log files in file stream sink.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 1e7f9b0edd91c..c1d13d14b05f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -121,6 +121,16 @@ class CollapseProjectSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("SPARK-36718: do not collapse project if non-cheap expressions will be repeated") { + val query = testRelation + .select(('a + 1).as('a_plus_1)) + .select(('a_plus_1 + 'a_plus_1).as('a_2_plus_2)) + .analyze + + val optimized = Optimize.execute(query) + comparePlans(optimized, query) + } + test("preserve top-level alias metadata while collapsing projects") { def hasMetadata(logicalPlan: LogicalPlan): Boolean = { logicalPlan.asInstanceOf[Project].projectList.exists(_.metadata.contains("key")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala index b1baeccbe94b9..eb3899c9187db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala @@ -57,7 +57,14 @@ class ScanOperationSuite extends SparkFunSuite { test("Project which has the same non-deterministic expression with its child Project") { val project3 = Project(Seq(colA, colR), Project(Seq(colA, aliasR), relation)) - assert(ScanOperation.unapply(project3).isEmpty) + project3 match { + case ScanOperation(projects, filters, _: Project) => + assert(projects.size === 2) + assert(projects(0) === colA) + assert(projects(1) === colR) + assert(filters.isEmpty) + case _ => assert(false) + } } test("Project 
which has different non-deterministic expressions with its child Project") { @@ -73,13 +80,18 @@ class ScanOperationSuite extends SparkFunSuite { test("Filter with non-deterministic Project") { val filter1 = Filter(EqualTo(colA, Literal(1)), Project(Seq(colA, aliasR), relation)) - assert(ScanOperation.unapply(filter1).isEmpty) + filter1 match { + case ScanOperation(projects, filters, _: Filter) => + assert(projects.size === 2) + assert(filters.isEmpty) + case _ => assert(false) + } } test("Non-deterministic Filter with deterministic Project") { - val filter3 = Filter(EqualTo(MonotonicallyIncreasingID(), Literal(1)), + val filter2 = Filter(EqualTo(MonotonicallyIncreasingID(), Literal(1)), Project(Seq(colA, colB), relation)) - filter3 match { + filter2 match { case ScanOperation(projects, filters, _: LocalRelation) => assert(projects.size === 2) assert(projects(0) === colA) @@ -91,7 +103,11 @@ class ScanOperationSuite extends SparkFunSuite { test("Deterministic filter which has a non-deterministic child Filter") { - val filter4 = Filter(EqualTo(colA, Literal(1)), Filter(EqualTo(aliasR, Literal(1)), relation)) - assert(ScanOperation.unapply(filter4).isEmpty) + val filter3 = Filter(EqualTo(colA, Literal(1)), Filter(EqualTo(aliasR, Literal(1)), relation)) + filter3 match { + case ScanOperation(projects, filters, _: Filter) => + assert(filters.isEmpty) + case _ => assert(false) + } } } From f113950182fd638000f72b172a8390cd6e2562aa Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 22 Mar 2022 15:45:51 +0800 Subject: [PATCH 38/53] [SPARK-38432][SQL] Refactor framework so as JDBC dialect could compile filter by self way ### What changes were proposed in this pull request? Currently, Spark DS V2 can push down filters into a JDBC source, but only the most basic form of filter is supported, and a JDBC source cannot compile the filters in its own dialect-specific way. This PR refactors the framework so that a JDBC dialect can compile expressions in its own way. First, the framework translates catalyst expressions into DS V2 filters. Second, the JDBC dialect compiles the DS V2 filters into its own SQL syntax (a minimal sketch of this two-step flow follows this commit message). ### Why are the changes needed? Make the framework more generally reusable. ### Does this PR introduce _any_ user-facing change? No. The feature is not released yet. ### How was this patch tested? Existing tests. Closes #35768 from beliefer/SPARK-38432_new. 
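To make the two-step flow above concrete, here is a minimal, self-contained Scala sketch. It is not the Spark API: `SimplePredicate`, `FieldRef`, `LiteralValue`, `GeneralPredicate`, `SqlDialect`, `AnsiDialect`, and `BacktickDialect` are all hypothetical names standing in for the DS V2 predicate tree and the per-dialect compilation hook; the sketch only illustrates how a dialect can decide for itself how an already-translated predicate is rendered as SQL, or decline to compile it.

```scala
// Hypothetical sketch only -- these types are NOT Spark APIs. They mirror the
// two-step flow described above: a predicate tree (already translated from
// catalyst expressions) is compiled to SQL text by a dialect, and each dialect
// may override parts of that compilation.
object DialectCompileSketch extends App {

  sealed trait SimplePredicate
  case class FieldRef(name: String) extends SimplePredicate
  case class LiteralValue(value: Any) extends SimplePredicate
  case class GeneralPredicate(op: String, children: Seq[SimplePredicate]) extends SimplePredicate

  trait SqlDialect {
    // Dialects can override identifier quoting (or any other piece of the compilation).
    def quote(name: String): String = "\"" + name + "\""

    // Returns None for predicates the dialect cannot compile, so they are simply not pushed down.
    def compile(p: SimplePredicate): Option[String] = p match {
      case FieldRef(name)  => Some(quote(name))
      case LiteralValue(v) => Some(v.toString)
      case GeneralPredicate(op, Seq(l, r)) if Set("=", "<", ">", "AND", "OR")(op) =>
        for (ls <- compile(l); rs <- compile(r)) yield s"($ls $op $rs)"
      case _ => None
    }
  }

  object AnsiDialect extends SqlDialect
  object BacktickDialect extends SqlDialect {
    override def quote(name: String): String = s"`$name`"
  }

  // Step 1 (not shown) would have translated a catalyst Expression into this tree.
  val pred = GeneralPredicate("AND", Seq(
    GeneralPredicate(">", Seq(FieldRef("age"), LiteralValue(18))),
    GeneralPredicate("=", Seq(FieldRef("dept"), LiteralValue(1)))))

  println(AnsiDialect.compile(pred))     // Some((("age" > 18) AND ("dept" = 1)))
  println(BacktickDialect.compile(pred)) // Some(((`age` > 18) AND (`dept` = 1)))
}
```

The actual changes in this patch wire the analogous hook through `V2ExpressionSQLBuilder` and the `JdbcDialect` implementations listed in the diffstat below (H2, Oracle, Postgres, etc.), so each dialect can emit its own syntax or skip predicates it cannot express.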
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- project/MimaExcludes.scala | 1787 +---------------- .../sql/connector/expressions/Expression.java | 17 + .../expressions/GeneralScalarExpression.java | 98 +- .../sql/connector/expressions/Literal.java | 3 + .../connector/expressions/NamedReference.java | 6 + .../sql/connector/expressions/SortOrder.java | 3 + .../sql/connector/expressions/Transform.java | 8 +- .../connector/expressions/aggregate/Avg.java | 3 + .../expressions/aggregate/Count.java | 3 + .../expressions/aggregate/CountStar.java | 4 + .../aggregate/GeneralAggregateFunc.java | 13 +- .../connector/expressions/aggregate/Max.java | 3 + .../connector/expressions/aggregate/Min.java | 3 + .../connector/expressions/aggregate/Sum.java | 3 + .../expressions/filter/AlwaysFalse.java | 30 +- .../expressions/filter/AlwaysTrue.java | 28 +- .../sql/connector/expressions/filter/And.java | 14 +- .../expressions/filter/BinaryComparison.java | 60 - .../expressions/filter/BinaryFilter.java | 65 - .../expressions/filter/EqualNullSafe.java | 40 - .../connector/expressions/filter/EqualTo.java | 39 - .../connector/expressions/filter/Filter.java | 40 - .../expressions/filter/GreaterThan.java | 39 - .../filter/GreaterThanOrEqual.java | 39 - .../sql/connector/expressions/filter/In.java | 76 - .../expressions/filter/IsNotNull.java | 58 - .../connector/expressions/filter/IsNull.java | 58 - .../expressions/filter/LessThan.java | 39 - .../expressions/filter/LessThanOrEqual.java | 39 - .../sql/connector/expressions/filter/Not.java | 31 +- .../sql/connector/expressions/filter/Or.java | 14 +- .../expressions/filter/Predicate.java | 149 ++ .../expressions/filter/StringContains.java | 39 - .../expressions/filter/StringEndsWith.java | 39 - .../expressions/filter/StringPredicate.java | 60 - .../expressions/filter/StringStartsWith.java | 41 - .../read/SupportsPushDownV2Filters.java | 38 +- .../util/V2ExpressionSQLBuilder.java | 100 +- .../apache/spark/sql/sources/filters.scala | 60 + .../expressions/TransformExtractorSuite.scala | 1 - .../catalyst/util/V2ExpressionBuilder.scala | 112 +- .../sql/execution/DataSourceScanExec.scala | 11 +- .../datasources/DataSourceStrategy.scala | 11 +- .../execution/datasources/jdbc/JDBCRDD.scala | 68 +- .../datasources/jdbc/JDBCRelation.scala | 18 +- .../datasources/v2/DataSourceV2Strategy.scala | 133 +- .../datasources/v2/PushDownUtils.scala | 17 +- .../datasources/v2/PushedDownOperators.scala | 4 +- .../v2/V2ScanRelationPushDown.scala | 8 +- .../datasources/v2/jdbc/JDBCScan.scala | 9 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 20 +- .../apache/spark/sql/jdbc/DB2Dialect.scala | 24 +- .../apache/spark/sql/jdbc/DerbyDialect.scala | 16 +- .../org/apache/spark/sql/jdbc/H2Dialect.scala | 28 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 22 +- .../spark/sql/jdbc/MsSqlServerDialect.scala | 16 +- .../apache/spark/sql/jdbc/MySQLDialect.scala | 16 +- .../apache/spark/sql/jdbc/OracleDialect.scala | 28 +- .../spark/sql/jdbc/PostgresDialect.scala | 28 +- .../spark/sql/jdbc/TeradataDialect.scala | 28 +- .../JavaAdvancedDataSourceV2WithV2Filter.java | 75 +- .../sql/connector/DataSourceV2Suite.scala | 43 +- .../v2/DataSourceV2StrategySuite.scala | 43 + .../datasources/v2/V2FiltersSuite.scala | 204 -- .../datasources/v2/V2PredicateSuite.scala | 188 ++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 57 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 283 ++- 67 files changed, 1360 insertions(+), 3340 deletions(-) delete mode 100644 
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryFilter.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2FiltersSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dba74ac9bb217..cc148d9e247f6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,8 +34,75 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { - // Exclude rules for 3.2.x - lazy val v32excludes = v31excludes ++ Seq( + // Exclude rules for 3.4.x + lazy val v34excludes = v33excludes ++ Seq( + ) + + // Exclude rules for 3.3.x from 3.2.0 + lazy val v33excludes = v32excludes ++ Seq( + // [SPARK-35672][CORE][YARN] Pass user classpath entries to executors using config instead of command line + // The followings are necessary for Scala 2.13. 
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.CoarseGrainedExecutorBackend#Arguments.*"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.CoarseGrainedExecutorBackend#Arguments.*"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.CoarseGrainedExecutorBackend$Arguments$"), + + // [SPARK-37391][SQL] JdbcConnectionProvider tells if it modifies security context + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.jdbc.JdbcConnectionProvider.modifiesSecurityContext"), + + // [SPARK-37780][SQL] QueryExecutionListener support SQLConf as constructor parameter + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this"), + // [SPARK-37786][SQL] StreamingQueryListener support use SQLConf.get to get corresponding SessionState's SQLConf + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.this"), + // [SPARK-38432][SQL] Refactor framework so as JDBC dialect could compile filter by self way + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.toV2"), + + // [SPARK-37600][BUILD] Upgrade to Hadoop 3.3.2 + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4Compressor"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4Factory"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4SafeDecompressor") + ) + + // Exclude rules for 3.2.x from 3.1.1 + lazy val v32excludes = Seq( + // Spark Internals + ProblemFilters.exclude[Problem]("org.apache.spark.rpc.*"), + ProblemFilters.exclude[Problem]("org.spark-project.jetty.*"), + ProblemFilters.exclude[Problem]("org.spark_project.jetty.*"), + ProblemFilters.exclude[Problem]("org.sparkproject.jetty.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.internal.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.unused.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.unsafe.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.memory.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.util.collection.unsafe.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.errors.*"), + // DSv2 catalog and expression APIs are unstable yet. We should enable this back. + ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.catalog.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.expressions.*"), + // Avro source implementation is internal. 
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.v2.avro.*"), + + // [SPARK-34848][CORE] Add duration to TaskMetricDistributions + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this"), + + // [SPARK-34488][CORE] Support task Metrics Distributions and executor Metrics Distributions + // in the REST API call for a specified stage + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), + + // [SPARK-36173][CORE] Support getting CPU number in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.cpus"), + + // [SPARK-35896] Include more granular metrics for stateful operators in StreamingQueryProgress + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"), + + (problem: Problem) => problem match { + case MissingClassProblem(cls) => !cls.fullName.startsWith("org.sparkproject.jpmml") && + !cls.fullName.startsWith("org.sparkproject.dmg.pmml") + case _ => true + }, + // [SPARK-33808][SQL] DataSource V2: Build logical writes in the optimizer ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.connector.write.V1WriteBuilder"), @@ -72,1722 +139,10 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions") ) - // Exclude rules for 3.1.x - lazy val v31excludes = v30excludes ++ Seq( - // mima plugin update caused new incompatibilities to be detected - // core module - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleMapOutputWriter.commitAllPartitions"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.environmentDetails"), - // mllib module - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.totalIterations"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.$init$"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.labels"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.truePositiveRateByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.falsePositiveRateByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.precisionByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.recallByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.fMeasureByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.fMeasureByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.accuracy"), - 
ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedTruePositiveRate"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFalsePositiveRate"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.roc"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.areaUnderROC"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.pr"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.fMeasureByThreshold"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.precisionByThreshold"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.recallByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.FMClassifier.trainImpl"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.FMRegressor.trainImpl"), - // [SPARK-31077] Remove ChiSqSelector dependency on mllib.ChiSqSelectorModel - // private constructor - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.this"), - - // [SPARK-31127] Implement abstract Selector - // org.apache.spark.ml.feature.ChiSqSelectorModel type hierarchy change - // before: class ChiSqSelector extends Estimator with ChiSqSelectorParams - // after: class ChiSqSelector extends PSelector - // false positive, no binary incompatibility - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelector"), - - // [SPARK-24634] Add a new metric regarding number of inputs later than watermark plus allowed delay - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.$default$4"), - - //[SPARK-31893] Add a generic ClassificationSummary trait - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.weightCol"), - 
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.weightCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), - 
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.weightCol"), - - // [SPARK-32879] Pass SparkSession.Builder options explicitly to SparkSession - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession.this") - ) - - // Exclude rules for 3.0.x - lazy val v30excludes = v24excludes ++ Seq( - // [SPARK-23429][CORE] Add executor memory metrics to heartbeat and expose in executors REST API - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.this"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate$"), - - // [SPARK-29306] Add support for Stage level scheduling for executors - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productPrefix"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.toString"), - - // [SPARK-29399][core] Remove old ExecutorPlugin interface. 
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ExecutorPlugin"), - - // [SPARK-28980][SQL][CORE][MLLIB] Remove more old deprecated items in Spark 3 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.clustering.KMeans.train"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.classification.LogisticRegressionWithSGD$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.classification.LogisticRegressionWithSGD.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.RidgeRegressionWithSGD$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.RidgeRegressionWithSGD.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.LassoWithSGD.this"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LassoWithSGD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD$"), - - // [SPARK-28486][CORE][PYTHON] Map PythonBroadcast's data file to a BroadcastBlock to avoid delete by GC - ProblemFilters.exclude[InaccessibleMethodProblem]("java.lang.Object.finalize"), - - // [SPARK-27366][CORE] Support GPU Resources in Spark job scheduling - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.resources"), - - // [SPARK-29417][CORE] Resource Scheduling - add TaskContext.resource java api - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.resourcesJMap"), - - // [SPARK-27410][MLLIB] Remove deprecated / no-op mllib.KMeans getRuns, setRuns - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.getRuns"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.setRuns"), - - // [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.UnaryTransformer.this"), - - // [SPARK-27090][CORE] Removing old LEGACY_DRIVER_IDENTIFIER ("") - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.LEGACY_DRIVER_IDENTIFIER"), - - // [SPARK-25838] Remove formatVersion from Saveable - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.LocalLDAModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeansModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.PowerIterationClusteringModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.GaussianMixtureModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.recommendation.MatrixFactorizationModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.formatVersion"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.Word2VecModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.classification.SVMModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.classification.LogisticRegressionModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.classification.NaiveBayesModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.Saveable.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.FPGrowthModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.PrefixSpanModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.IsotonicRegressionModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.RidgeRegressionModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.LassoModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionModel.formatVersion"), - - // [SPARK-26132] Remove support for Scala 2.11 in Spark 3.0.0 - ProblemFilters.exclude[DirectAbstractMethodProblem]("scala.concurrent.Future.transformWith"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("scala.concurrent.Future.transform"), - - // [SPARK-26254][CORE] Extract Hive + Kafka dependencies from Core. - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.deploy.security.HiveDelegationTokenProvider"), - - // [SPARK-26329][CORE] Faster polling of executor memory metrics. 
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd.apply"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd.copy$default$6"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd.this"), - - // [SPARK-26311][CORE]New feature: apply custom log URL pattern for executor log URLs - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerApplicationStart.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerApplicationStart.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerApplicationStart.this"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerApplicationStart$"), - - // [SPARK-27630][CORE] Properly handle task end events from completed stages - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerSpeculativeTaskSubmitted.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerSpeculativeTaskSubmitted.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerSpeculativeTaskSubmitted.this"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerSpeculativeTaskSubmitted$"), - - // [SPARK-26632][Core] Separate Thread Configurations of Driver and Executor - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"), - - // [SPARK-16872][ML][PYSPARK] Impl Gaussian Naive Bayes Classifier - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.NaiveBayesModel.this"), - - // [SPARK-25765][ML] Add training cost to BisectingKMeans summary - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel.this"), - - // [SPARK-24243][CORE] Expose exceptions from InProcessAppHandle - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.launcher.SparkAppHandle.getError"), - - // [SPARK-25867] Remove KMeans computeCost - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), - - // [SPARK-26127] Remove deprecated setters from tree regression and classification models - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setImpurity"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxIter"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setStepSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxBins"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setNumTrees"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxIter"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setStepSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSubsamplingRate"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setNumTrees"), - - // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.stat.test.BinarySampleBeanInfo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LabeledPointBeanInfo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), - - // [SPARK-28780][ML] Delete the incorrect setWeightCol method in LinearSVCModel - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LinearSVCModel.setWeightCol"), - - // [SPARK-29645][ML][PYSPARK] ML add param RelativeError - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.relativeError"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getRelativeError"), - - // [SPARK-28968][ML] Add HasNumFeatures in the scala side - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.FeatureHasher.getNumFeatures"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.FeatureHasher.numFeatures"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.HashingTF.getNumFeatures"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.HashingTF.numFeatures"), - - // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleBytesWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleWriteTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.fMeasure"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.recall"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.precision"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.context"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.context"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.GeneralMLWriter.context"), - - // [SPARK-25737] Remove JavaSparkContextVarargsWorkaround - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.api.java.JavaSparkContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.union"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.union"), - - // [SPARK-16775] Remove deprecated accumulator v1 APIs - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulator"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulator$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulableParam"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$FloatAccumulatorParam$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$DoubleAccumulatorParam$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$LongAccumulatorParam$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$IntAccumulatorParam$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.accumulable"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.accumulableCollection"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.accumulator"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.LegacyAccumulatorWrapper"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.intAccumulator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulable"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.doubleAccumulator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulator"), - - // [SPARK-24109] Remove class SnappyOutputStreamWrapper - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"), - - // [SPARK-19287] JavaPairRDD flatMapValues requires function returning Iterable, not Iterator - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"), - - // [SPARK-25680] SQL execution listener shouldn't happen on execution thread - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this"), - - // [SPARK-25862][SQL] Remove rangeBetween APIs introduced in SPARK-21608 - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.unboundedFollowing"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.unboundedPreceding"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.currentRow"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.Window.rangeBetween"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.rangeBetween"), - - // [SPARK-23781][CORE] Merge token renewer functionality into HadoopDelegationTokenManager - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.nextCredentialRenewalTime"), - - // [SPARK-26133][ML] Remove deprecated OneHotEncoder and rename OneHotEncoderEstimator to OneHotEncoder - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.OneHotEncoderEstimator"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.OneHotEncoder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.OneHotEncoder.transform"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.OneHotEncoderEstimator$"), - - // [SPARK-30329][ML] add iterator/foreach methods for Vectors - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.activeIterator"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.activeIterator"), - - // [SPARK-26141] Enable custom metrics implementation in shuffle write - // Following are Java private classes - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"), - - // [SPARK-26139] Implement shuffle write metrics in SQL - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ShuffleDependency.this"), - - // [SPARK-26362][CORE] Remove 'spark.driver.allowMultipleContexts' to disallow multiple creation of SparkContexts - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.setActiveContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.markPartiallyConstructed"), - - // [SPARK-26457] Show hadoop configurations in HistoryServer environment tab - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this"), - - // [SPARK-30144][ML] Make MultilayerPerceptronClassificationModel extend MultilayerPerceptronParams - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.layers"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"), - - // [SPARK-30630][ML] Remove numTrees in GBT - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.numTrees"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.numTrees"), - - // Data Source V2 API changes - (problem: Problem) => problem match { - case MissingClassProblem(cls) => - !cls.fullName.startsWith("org.apache.spark.sql.sources.v2") - case _ => true - }, - - // [SPARK-27521][SQL] Move data source v2 to catalyst module - 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarBatch"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ArrowColumnVector"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarRow"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarArray"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarMap"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnVector"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThanOrEqual"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringEndsWith"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThanOrEqual$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.In$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Not"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNotNull"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThan"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThanOrEqual"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualNullSafe$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThan$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.In"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.And"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringStartsWith$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualNullSafe"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringEndsWith$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThanOrEqual$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Not$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNull$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThan$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNotNull$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Or"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualTo$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThan"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringContains"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Filter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNull"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualTo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.And$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Or$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringStartsWith"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringContains$"), - - // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), - 
ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.inputTypes"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullableTypes_="), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.dataType"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.f"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.this"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$2"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$1"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$3"), - - // [SPARK-11215][ML] Add multiple columns support to StringIndexer - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), - - // [SPARK-26616][MLlib] Expose document frequency in IDFModel - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf"), - - // [SPARK-28199][SS] Remove deprecated ProcessingTime - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.ProcessingTime"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.ProcessingTime$"), - - // [SPARK-25382][SQL][PYSPARK] Remove ImageSchema.readImages in 3.0 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.image.ImageSchema.readImages"), - - // [SPARK-25341][CORE] Support rolling back a shuffle map stage and re-generate the shuffle files - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.copy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.copy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleBlockId.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleBlockId.copy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleBlockId.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.apply"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.apply"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleBlockId.apply"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.mapId"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.mapId"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleBlockId.mapId"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.FetchFailed.mapId"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.FetchFailed$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.apply"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.FetchFailed.copy$default$5"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.copy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.FetchFailed.copy$default$3"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.this"), - - // [SPARK-28957][SQL] Copy any "spark.hive.foo=bar" spark properties into hadoop conf as "hive.foo=bar" - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.appendS3AndSparkHadoopConfigurations"), - - // [SPARK-29348] Add observable metrics. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryProgress.this"), - - // [SPARK-30377][ML] Make AFTSurvivalRegression extend Regressor - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.setFeaturesCol"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.setPredictionCol"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setFeaturesCol"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setLabelCol"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setPredictionCol"), - - // [SPARK-29543][SS][UI] Init structured streaming ui - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStartedEvent.this"), - - // [SPARK-30667][CORE] Add allGather method to BarrierTaskContext - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.RequestToSync") - ) - - // Exclude rules for 2.4.x - lazy val v24excludes = v23excludes ++ Seq( - // [SPARK-25248] add package private methods to TaskContext - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskFailed"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markInterrupted"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.fetchFailed"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskCompleted"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperties"), - - // [SPARK-10697][ML] Add lift to Association rules - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"), - - // [SPARK-24296][CORE] Replicate large blocks as a stream. 
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"), - // [SPARK-23528] Add numIter to ClusteringSummary - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.ClusteringSummary.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this"), - // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), - - // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"), - - // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printVersionAndExit"), - - // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), - - // [SPARK-20659] Remove StorageStatus, or make it private - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOnHeapStorageMemory"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOnHeapStorageMemory"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.getExecutorStorageStatus"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numBlocks"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocks"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.containsBlock"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddBlocksById"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), - - // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), - - // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), - - // [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter"), - - // [SPARK-21842][MESOS] Support Kerberos ticket renewal and creation in Mesos - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getDateOfNextUpdate"), - - // [SPARK-23366] Improve hot reading path in ReadAheadInputStream - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.ReadAheadInputStream.this"), - - // [SPARK-22941][CORE] Do not exit JVM when submit fails with in-process launcher. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.addJarToClasspath"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.mergeFileLists"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment$default$2"), - - // Data Source V2 API changes - // TODO: they are unstable APIs and should not be tracked by mima. 
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ReadSupportWithSchema"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.createDataReaderFactories"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.createBatchDataReaderFactories"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.planBatchInputPartitions"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanUnsafeRow"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader.createDataReaderFactories"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader.planInputPartitions"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownCatalystFilters"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataReader"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.getStatistics"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.estimateStatistics"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataReaderFactory"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createDataWriter"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createDataWriter"), - - // Changes to HasRawPredictionCol. 
- ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.rawPredictionCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.org$apache$spark$ml$param$shared$HasRawPredictionCol$_setter_$rawPredictionCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.getRawPredictionCol"), - - // [SPARK-15526][ML][FOLLOWUP] Make JPMML provided scope to avoid including unshaded JARs - (problem: Problem) => problem match { - case MissingClassProblem(cls) => - !cls.fullName.startsWith("org.sparkproject.jpmml") && - !cls.fullName.startsWith("org.sparkproject.dmg.pmml") && - !cls.fullName.startsWith("org.spark_project.jpmml") && - !cls.fullName.startsWith("org.spark_project.dmg.pmml") - case _ => true - } - ) - - // Exclude rules for 2.3.x - lazy val v23excludes = v22excludes ++ Seq( - // [SPARK-22897] Expose stageAttemptId in TaskContext - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), - - // SPARK-22789: Map-only continuous processing execution - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$9"), - - // SPARK-22372: Make cluster submission use SparkApplication. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getSecretKeyFromUserCredentials"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.isYarnMode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getCurrentUserCredentials"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.addSecretKeyToUserCredentials"), - - // SPARK-18085: Better History Server scalability for many / large applications - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.env.EnvironmentListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.exec.ExecutorsListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.storage.StorageListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageStatusListener"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorStageSummary.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkStatusTracker.this"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.jobs.JobProgressListener"), - - // [SPARK-20495][SQL] Add StorageLevel to cacheTable API - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), - - // [SPARK-19937] Add remote bytes read to disk. 
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetricDistributions.this"), - - // [SPARK-21276] Update lz4-java to the latest (v1.4.0) - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.io.LZ4BlockInputStream"), - - // [SPARK-17139] Add model summary for MultinomialLogisticRegression - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictionCol"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.labels"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.truePositiveRateByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.falsePositiveRateByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.precisionByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.recallByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.fMeasureByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.accuracy"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedTruePositiveRate"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFalsePositiveRate"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.asBinary"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$_setter_$org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics_="), - - // [SPARK-14280] Support Scala 2.12 - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"), - - // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala - 
ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), - - // [SPARK-21728][CORE] Allow SparkSubmit to use Logging - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFileList"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFile"), - - // [SPARK-21714][CORE][YARN] Avoiding re-uploading remote resources in yarn client mode - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment"), - - // [SPARK-22324][SQL][PYTHON] Upgrade Arrow to 0.8.0 - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.network.util.AbstractFileRegion.transfered"), - - // [SPARK-20643][CORE] Add listener implementation to collect app state - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$5"), - - // [SPARK-20648][CORE] Port JobsTab and StageTab to the new UI backend - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$12"), - - // [SPARK-21462][SS] Added batchId to StreamingQueryProgress.json - // [SPARK-21409][SS] Expose state store memory usage in SQL metrics and progress updates - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"), - - // [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentWatermarkMs"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentProcessingTimeMs"), - - // [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_="), - - // [SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid") - ) - - // Exclude rules for 2.2.x - lazy val v22excludes = v21excludes ++ Seq( - // [SPARK-20355] Add per application spark version on the history server headerpage - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), - - // [SPARK-19652][UI] Do auth checks for REST API access. 
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"), - - // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"), - - // [SPARK-18949] [SQL] Add repairTable API to Catalog - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.recoverPartitions"), - - // [SPARK-18537] Add a REST api to spark streaming - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.streaming.scheduler.StreamingListener.onStreamingStarted"), - - // [SPARK-19148][SQL] do not expose the external table concept in Catalog - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable"), - - // [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"), - - // [SPARK-19267] Fetch Failure handling robust to user error handling - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"), - - // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$10"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$11"), - - // [SPARK-17161] Removing Python-friendly constructors not needed - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"), - - // [SPARK-19820] Allow reason to be specified to task kill - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.TaskKilled$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.countTowardsTaskFailures"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.toErrorString"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.TaskKilled.toString"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.killTaskIfInterrupted"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getKillReason"), - - // [SPARK-19876] Add one time trigger, and improve Trigger APIs - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime"), - - // [SPARK-17471][ML] Add compressed method to ML matrices - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressed"), - 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedColMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedRowMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isRowMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isColMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSparseSizeInBytes"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDense"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparse"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseRowMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseRowMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseColMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getDenseSizeInBytes"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"), - - // [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum") - ) ++ Seq( - // [SPARK-17019] Expose on-heap and off-heap memory usage in various places - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.this"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.apply"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.StorageStatus.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.RDDDataDistribution.this") - ) - - // Exclude rules for 2.1.x - lazy val v21excludes = v20excludes ++ { - Seq( - // [SPARK-17671] Spark 2.0 history server summary page is slow even set spark.history.ui.maxApplications - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.deploy.history.HistoryServer.getApplicationList"), - // [SPARK-14743] Improve delegation token handling in secure cluster - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTimeFromNowToRenewal"), - // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"), - // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select"), - // [SPARK-16967] Move Mesos to Module - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX"), - // [SPARK-16240] ML persistence backward compatibility for LDA - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$"), - // [SPARK-17717] Add Find and Exists method to Catalog. - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getDatabase"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getTable"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getFunction"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), - - // [SPARK-17731][SQL][Streaming] Metrics for structured streaming - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SourceStatus.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.SourceStatus.offsetDesc"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.status"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SinkStatus.this"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryInfo"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.queryInfo"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.queryInfo"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.queryInfo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryStarted"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryStarted"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryStarted"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryProgress"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryProgress"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryTerminated"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryTerminated"), - - // [SPARK-18516][SQL] Split state and progress in 
streaming - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.SourceStatus"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.SinkStatus"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.sinkStatus"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.sourceStatuses"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQuery.id"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.lastProgress"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.recentProgress"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.id"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.get"), - - // [SPARK-17338][SQL] add global temp view - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropGlobalTempView"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView"), - - // [SPARK-18034] Upgrade to MiMa 0.1.11 to fix flakiness. - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.aggregationDepth"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.getAggregationDepth"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.org$apache$spark$ml$param$shared$HasAggregationDepth$_setter_$aggregationDepth_="), - - // [SPARK-18236] Reduce duplicate objects in Spark UI and HistoryServer - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.TaskInfo.accumulables"), - - // [SPARK-18657] Add StreamingQuery.runId - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.runId"), - - // [SPARK-18694] Add StreamingQuery.explain and exception to Python and fix StreamingQueryException - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryException$"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.startOffset"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.endOffset"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryException.query") - ) - } - - // Exclude rules for 2.0.x - lazy val v20excludes = { - Seq( - ProblemFilters.exclude[Problem]("org.apache.spark.rpc.*"), - ProblemFilters.exclude[Problem]("org.spark-project.jetty.*"), - ProblemFilters.exclude[Problem]("org.spark_project.jetty.*"), - ProblemFilters.exclude[Problem]("org.sparkproject.jetty.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.internal.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.unused.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.unsafe.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.memory.*"), - 
ProblemFilters.exclude[Problem]("org.apache.spark.util.collection.unsafe.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationAttemptInfo.$default$5"), - // SPARK-14042 Add custom coalescer support - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.coalesce"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rdd.PartitionCoalescer$LocationIterator"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.rdd.PartitionCoalescer"), - // SPARK-15532 Remove isRootContext flag from SQLContext. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.isRootContext"), - // SPARK-12600 Remove SQL deprecated methods - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.applySchema"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.parquetFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jdbc"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.dialectClassName"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.getSQLDialect"), - // SPARK-13664 Replace HadoopFsRelation with FileFormat - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.source.libsvm.LibSVMRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelationProvider"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelation$FileStatusCache"), - // SPARK-15543 Rename DefaultSources to make them more self-describing - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.source.libsvm.DefaultSource") - ) ++ Seq( - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory"), - // SPARK-14358 SparkListener from trait to abstract class - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.addSparkListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.JavaSparkListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkFirehoseListener"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.scheduler.SparkListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.jobs.JobProgressListener"), - 
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.exec.ExecutorsListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.env.EnvironmentListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.storage.StorageListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.StorageStatusListener") - ) ++ - Seq( - // SPARK-3369 Fix Iterable/Iterator in Java API - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.FlatMapFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.FlatMapFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.FlatMapFunction2.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.FlatMapFunction2.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.PairFlatMapFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.PairFlatMapFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.CoGroupFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.CoGroupFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.MapPartitionsFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.MapPartitionsFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.FlatMapGroupsFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.FlatMapGroupsFunction.call") - ) ++ - Seq( - // [SPARK-6429] Implement hashCode and equals together - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Partition.org$apache$spark$Partition$$super=uals") - ) ++ - Seq( - // SPARK-4819 replace Guava Optional - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getSparkHome"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner") - ) ++ - Seq( - // SPARK-12481 Remove Hadoop 1.x - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil"), - // SPARK-12615 Remove deprecated APIs in core - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.$default$6"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.numericRDDToDoubleRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intToIntWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intWritableConverter"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.writableWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToAsyncRDDActions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.boolToBoolWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longToLongWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToOrderedRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.booleanWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringToText"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleToDoubleWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToSequenceFileRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesToBytesWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatToFloatWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions$default$4"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addOnCompleteCallback"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.runningLocally"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.attemptId"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.defaultMinSplits"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.runJob"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.runJob"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.tachyonFolderName"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.initLocalProperties"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearJars"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearFiles"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith$default$2"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.toArray"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith$default$2"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.filterWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.foreachWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit$default$2"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.SequenceFileRDDFunctions.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.splits"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.toArray"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.externalBlockStoreFolderName"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockManager"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore") - ) ++ Seq( - // SPARK-12149 Added new fields to ExecutorSummary - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") - ) ++ - // SPARK-12665 Remove deprecated and unused classes - Seq( - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.graphx.GraphKryoRegistrator"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$") - ) ++ Seq( - // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge") - ) ++ Seq( - // SPARK-12510 Refactor ActorReceiver to support Java - ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") - ) ++ Seq( - // SPARK-12895 Implement TaskMetrics using accumulators - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.internalMetricsToAccumulators"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectInternalAccumulators"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectAccumulators") - ) ++ Seq( - // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Accumulator.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.Accumulator.initialValue") - ) ++ Seq( - // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":") - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log__="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log__="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log__="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") - ) ++ Seq( - // SPARK-12689 Migrate DDL parsing to the newly absorbed parser - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser") - ) ++ Seq( - // SPARK-7799 Add "streaming-akka" project - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$6"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$5"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$4"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$3"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.streaming.zeromq.ZeroMQReceiver"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver$Supervisor") - ) ++ Seq( - // SPARK-12348 Remove deprecated Streaming APIs. 
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.dstream.DStream.foreach"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions$default$4"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.awaitTermination"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.networkStream"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.api.java.JavaStreamingContextFactory"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.awaitTermination"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.sc"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreach"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.getOrCreate") - ) ++ Seq( - // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus") - ) ++ Seq( - // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") - ) ++ Seq( - // SPARK-6363 Make Scala 2.11 the default Scala version - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") - ) ++ Seq( - // SPARK-7889 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI"), - // SPARK-13296 - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.UDFRegistration.register"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$") - ) ++ Seq( - // SPARK-12995 Remove deprecated APIs in graphx - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.lib.SVDPlusPlus.runSVDPlusPlus"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets") - ) ++ Seq( - // SPARK-13426 Remove the support of SIMR - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX") - ) ++ Seq( - // SPARK-13413 Remove SparkContext.metricsSystem/schedulerBackend_ setter - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metricsSystem"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.schedulerBackend_=") - ) ++ Seq( - // SPARK-13220 Deprecate yarn-client and yarn-cluster mode - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") - ) ++ Seq( - // SPARK-13465 TaskContext. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener") - ) ++ Seq ( - // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") - ) ++ Seq( - // SPARK-13526 Move SQLContext per-session states to new class - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.sql.UDFRegistration.this") - ) ++ Seq( - // [SPARK-13486][SQL] Move SQLConf into an internal package - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$") - ) ++ Seq( - //SPARK-11011 UserDefinedType serialization should be strongly typed - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), - // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition") - ) ++ Seq( - // [SPARK-13244][SQL] Migrates DataFrame to Dataset - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.tables"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.sql"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.baseRelationToDataFrame"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.table"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrame.apply"), - - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.localSeqToDataFrameHolder"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.stringRddToDataFrameHolder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.rddToDataFrameHolder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"), - - // [SPARK-14451][SQL] Move encoder definition into Aggregator interface - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"), - - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions") - ) ++ Seq( - // [SPARK-13686][MLLIB][STREAMING] Add a constructor parameter `reqParam` to (Streaming)LinearRegressionWithSGD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD.this") - ) ++ Seq( - // SPARK-15250 Remove deprecated json API in DataFrameReader - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameReader.json") - ) ++ Seq( - // SPARK-13920: MIMA checks should apply to @Experimental and @DeveloperAPI APIs - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineCombinersByKey"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineValuesByKey"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.run"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.runJob"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.actorSystem"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.cacheManager"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getConfigurationFromJobContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTaskAttemptIDFromTaskAttemptContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.newConfiguration"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback_="), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.copy"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.setBytesReadCallback"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.updateBytesRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decFetchWaitTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decLocalBlocksFetched"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRecordsRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBlocksFetched"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBytesRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleBytesWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleWriteTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleBytesWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleWriteTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.setShuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.PCAModel.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.taskMetrics"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.TaskInfo.attempt"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.ExperimentalMethods.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUDF"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUdf"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.cumeDist"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.denseRank"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.inputFileName"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.isNaN"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.percentRank"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.rowNumber"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.sparkPartitionId"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.externalBlockStoreSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsedByRdd"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatusListener.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.scheduler.BatchInfo.streamIdToNumRecords"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.storageStatusList"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.storage.StorageListener.storageStatusList"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.apply"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.copy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.InputMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.OutputMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transformImpl"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.extractLabeledPoints"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.train"), - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.GBTClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayes.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRest.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRestModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.RandomForestClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeans.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logLikelihood"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logPerplexity"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.RegressionEvaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Binarizer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Bucketizer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelector.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.HashingTF.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDF.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDFModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IndexToString.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Interaction.transform"), - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScaler.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.OneHotEncoder.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCA.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCAModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormula.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormulaModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.SQLTransformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StopWordsRemover.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorAssembler.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorSlicer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2Vec.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALS.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.GBTRegressor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.extractWeightedLabeledPoints"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.extractWeightedLabeledPoints"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegression.train"), - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionTrainingSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplit.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.RegressionMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameWriter.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.broadcast"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.callUDF"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.InsertableRelation.insert"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.fMeasureByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.pr"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.precisionByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.predictions"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.recallByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.roc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.describeTopics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.getVectors"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.itemFactors"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.userFactors"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.predictions"), - 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.residuals"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.name"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.value"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.drop"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.fill"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.replace"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.jdbc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.json"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.load"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.orc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.parquet"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.table"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.text"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.crosstab"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.freqItems"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.sampleBy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.createExternalTable"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.emptyDataFrame"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.range"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.functions.udf"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.JobLogger"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorHelper"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics$"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics$"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics$"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.functions$"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Predictor.train"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), - 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListener.onOtherEvent"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.InsertableRelation.insert") - ) ++ Seq( - // [SPARK-13926] Automatically use Kryo serializer when shuffling RDDs with simple types - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ShuffleDependency.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ShuffleDependency.serializer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.Serializer$") - ) ++ Seq( - // SPARK-13927: add row/column iterator to local matrices - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.rowIter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.colIter") - ) ++ Seq( - // SPARK-13948: MiMa Check should catch if the visibility change to `private` - // TODO(josh): Some of these may be legitimate incompatibilities; we should follow up before the 2.0.0 release - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.toDS"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.askTimeout"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.lookupTimeout"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.UnaryTransformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.select"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.toDF"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Logging.initializeLogIfNecessary"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerEvent.logEvent"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance") - ) ++ Seq( - // [SPARK-14014] Replace existing analysis.Catalog with SessionCatalog - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.this") - ) ++ Seq( - // [SPARK-13928] Move org.apache.spark.Logging into org.apache.spark.internal.Logging - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Logging"), - (problem: Problem) => problem match { - case MissingTypesProblem(_, missing) - if missing.map(_.fullName).sameElements(Seq("org.apache.spark.Logging")) => false - case _ => true - } - ) ++ Seq( - // [SPARK-13990] Automatically pick serializer when caching RDDs - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.uploadBlock") - ) ++ Seq( - // [SPARK-14089][CORE][MLLIB] Remove methods that has been deprecated since 1.1, 1.2, 1.3, 1.4, and 1.5 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.getThreadLocal"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeReduce"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeAggregate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.defaultStategy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.saveLabeledData"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLabeledData"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.optimization.LBFGS.setMaxNumIterations"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.setScoreCol") - ) ++ Seq( - // [SPARK-14205][SQL] remove trait Queryable - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset") - ) ++ Seq( - // [SPARK-11262][ML] Unit test for gradient, loss layers, memory management - // for multilayer perceptron. - // This class is marked as `private`. - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction") - ) ++ Seq( - // [SPARK-13674][SQL] Add wholestage codegen support to Sample - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this") - ) ++ Seq( - // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this") - ) ++ Seq( - // [SPARK-14437][Core] Use the address that NettyBlockTransferService listens to create BlockManagerId - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.this") - ) ++ Seq( - // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") - ) ++ Seq( - // [SPARK-14475] Propagate user-defined context from driver to executors - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), - // [SPARK-14617] Remove deprecated APIs in TaskMetrics - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$"), - // [SPARK-14628] Simplify task metrics by always tracking read/write metrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.readMethod"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.writeMethod") - ) ++ Seq( - // 
SPARK-14628: Always track input/output/shuffle metrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.totalBlocksFetched"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.inputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.outputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleWriteMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleReadMetrics"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.inputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.outputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleWriteMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleReadMetrics"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") - ) ++ Seq( - // SPARK-13643: Move functionality from SQLContext to SparkSession - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.getSchema") - ) ++ Seq( - // [SPARK-14407] Hides HadoopFsRelation related data source API into execution package - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriterFactory") - ) ++ Seq( - // SPARK-14734: Add conversions between mllib and ml Vector, Matrix types - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asML"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asML") - ) ++ Seq( - // SPARK-14704: Create accumulators in TaskMetrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.this") - ) ++ Seq( - // SPARK-14861: Replace internal usages of SQLContext with SparkSession - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.clustering.LocalLDAModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.clustering.DistributedLDAModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.clustering.LDAModel.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.ml.clustering.LDAModel.sqlContext"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.sql.Dataset.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.sql.DataFrameReader.this") - ) ++ Seq( - // SPARK-14542 configurable buffer size for pipe RDD - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.pipe"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.pipe") - ) ++ Seq( - // 
[SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable") - ) ++ Seq( - // [SPARK-14952][Core][ML] Remove methods deprecated in 1.6 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.input.PortableDataStream.close"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.weights"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionModel.weights") - ) ++ Seq( - // [SPARK-10653] [Core] Remove unnecessary things from SparkEnv - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.sparkFilesDir"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.blockTransferService") - ) ++ Seq( - // SPARK-14654: New accumulator API - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ExceptionFailure$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.metrics"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.remoteBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.totalBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.localBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.remoteBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.localBlocksFetched") - ) ++ Seq( - // [SPARK-14615][ML] Use the new ML Vector and Matrix in the ML pipeline based algorithms - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.getOldDocConcentration"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.estimatedDocConcentration"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.topicsMatrix"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.clusterCenters"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LabelConverter.decodeLabel"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LabelConverter.encodeLabeledPoint"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.weights"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.predict"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.predictRaw"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.raw2probabilityInPlace"), - 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.theta"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.pi"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.probability2prediction"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.predictRaw"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.raw2prediction"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.raw2probabilityInPlace"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.predict"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.coefficients"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.raw2prediction"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.predictRaw"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.ClassificationModel.predictRaw"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.ElementwiseProduct.getScalingVec"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ElementwiseProduct.setScalingVec"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.PCAModel.pc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.originalMax"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.originalMin"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.IDFModel.idf"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.mean"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.std"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.predict"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.coefficients"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.predictQuantiles"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.predictions"), - 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.boundaries"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.predict"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.coefficients"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.this") - ) ++ Seq( - // [SPARK-15290] Move annotations, like @Since / @DeveloperApi, into spark-tags - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.package$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.package"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.Private"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.AlphaComponent"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.Experimental"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.DeveloperApi") - ) ++ Seq( - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asBreeze"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asBreeze") - ) ++ Seq( - // [SPARK-15914] Binary compatibility is broken since consolidation of Dataset and DataFrame - // in Spark 2.0. However, source level compatibility is still maintained. - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.load"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jsonFile"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jdbc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.parquetFile"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.applySchema") - ) ++ Seq( - // SPARK-17096: Improve exception string reported through the StreamingQueryListener - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.stackTrace"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.this") - ) ++ Seq( - // SPARK-17406 limit timeline executor events - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorIdToData"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksActive"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksComplete"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputRecords"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksFailed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleWrite"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToDuration"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputBytes"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToLogUrls"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputBytes"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputRecords"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTotalCores"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksMax"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToJvmGCTime") - ) ++ Seq( - // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") - ) ++ Seq( - // [SPARK-17498] StringIndexer enhancement for handling unseen labels - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel") - ) ++ Seq( - // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") - ) ++ Seq( - // [SPARK-12221] Add CPU time to metrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") - ) ++ Seq( - // [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), - 
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassifier"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassificationModel"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressor"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressionModel"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy") - ) ++ Seq( - // [SPARK-21680][ML][MLLIB]optimize Vector compress - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.toSparseWithSize"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.toSparseWithSize") - ) ++ Seq( - // [SPARK-3181][ML]Implement huber loss for LinearRegression. 
- ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.org$apache$spark$ml$param$shared$HasLoss$_setter_$loss_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.getLoss"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.loss") - ) - } - def excludes(version: String) = version match { + case v if v.startsWith("3.4") => v34excludes + case v if v.startsWith("3.3") => v33excludes case v if v.startsWith("3.2") => v32excludes - case v if v.startsWith("3.1") => v31excludes - case v if v.startsWith("3.0") => v30excludes - case v if v.startsWith("2.4") => v24excludes - case v if v.startsWith("2.3") => v23excludes - case v if v.startsWith("2.2") => v22excludes - case v if v.startsWith("2.1") => v21excludes - case v if v.startsWith("2.0") => v20excludes case _ => Seq() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java index 9f6c0975ae0e1..76dfe73f666cf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.expressions; +import java.util.Arrays; + import org.apache.spark.annotation.Evolving; /** @@ -26,8 +28,23 @@ */ @Evolving public interface Expression { + Expression[] EMPTY_EXPRESSION = new Expression[0]; + /** * Format the expression as a human readable SQL-like string. */ default String describe() { return this.toString(); } + + /** + * Returns an array of the children of this node. Children should not change. + */ + Expression[] children(); + + /** + * List of fields or columns that are referenced by this expression. + */ + default NamedReference[] references() { + return Arrays.stream(children()).map(e -> e.references()) + .flatMap(Arrays::stream).distinct().toArray(NamedReference[]::new); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index b3dd2cbfe3d7d..8952761f9ef34 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -19,77 +19,19 @@ import java.io.Serializable; import java.util.Arrays; +import java.util.Objects; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; -// scalastyle:off line.size.limit /** * The general representation of SQL scalar expressions, which contains the upper-cased - * expression name and all the children expressions. + * expression name and all the children expressions. Please also see {@link Predicate} + * for the supported predicate expressions. *

 * The currently supported SQL scalar expressions:
 *
 - *  Name: IS_NULL
 - *   - SQL semantic: expr IS NULL
 - *   - Since version: 3.3.0
 - *  Name: IS_NOT_NULL
 - *   - SQL semantic: expr IS NOT NULL
 - *   - Since version: 3.3.0
 - *  Name: =
 - *   - SQL semantic: expr1 = expr2
 - *   - Since version: 3.3.0
 - *  Name: !=
 - *   - SQL semantic: expr1 != expr2
 - *   - Since version: 3.3.0
 - *  Name: <>
 - *   - SQL semantic: expr1 <> expr2
 - *   - Since version: 3.3.0
 - *  Name: <=>
 - *   - SQL semantic: expr1 <=> expr2
 - *   - Since version: 3.3.0
 - *  Name: <
 - *   - SQL semantic: expr1 < expr2
 - *   - Since version: 3.3.0
 - *  Name: <=
 - *   - SQL semantic: expr1 <= expr2
 - *   - Since version: 3.3.0
 - *  Name: >
 - *   - SQL semantic: expr1 > expr2
 - *   - Since version: 3.3.0
 - *  Name: >=
 - *   - SQL semantic: expr1 >= expr2
 - *   - Since version: 3.3.0
 *  Name: +
 *   - SQL semantic: expr1 + expr2
@@ -138,24 +80,6 @@
 *   - Since version: 3.3.0
 - *  Name: AND
 - *   - SQL semantic: expr1 AND expr2
 - *   - Since version: 3.3.0
 - *  Name: OR
 - *   - SQL semantic: expr1 OR expr2
 - *   - Since version: 3.3.0
 - *  Name: NOT
 - *   - SQL semantic: NOT expr
 - *   - Since version: 3.3.0
 *  Name: ~
 *   - SQL semantic: ~ expr
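For context only (not part of this patch): a minimal sketch of how a connector might build and inspect one of the remaining scalar expressions with the new children-based API. The column name and literal value are made up, and Expressions.column/Expressions.literal are assumed to be the existing DSV2 factory methods.

    // Hedged sketch, not part of this patch. "price" and 10 are hypothetical values.
    import org.apache.spark.sql.connector.expressions.Expression;
    import org.apache.spark.sql.connector.expressions.Expressions;
    import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
    import org.apache.spark.sql.connector.expressions.NamedReference;

    public class ScalarExpressionSketch {
      public static void main(String[] args) {
        // price + 10, expressed as a general scalar expression with two children.
        Expression plus = new GeneralScalarExpression(
            "+", new Expression[]{ Expressions.column("price"), Expressions.literal(10) });
        Expression[] children = plus.children();    // the two operands
        NamedReference[] refs = plus.references();  // columns collected from the children
        System.out.println(plus.describe() + " has " + children.length
            + " children and references " + refs.length + " column(s)");
      }
    }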
      • @@ -176,7 +100,6 @@ * * @since 3.3.0 */ -// scalastyle:on line.size.limit @Evolving public class GeneralScalarExpression implements Expression, Serializable { private String name; @@ -190,6 +113,19 @@ public GeneralScalarExpression(String name, Expression[] children) { public String name() { return name; } public Expression[] children() { return children; } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + GeneralScalarExpression that = (GeneralScalarExpression) o; + return Objects.equals(name, that.name) && Arrays.equals(children, that.children); + } + + @Override + public int hashCode() { + return Objects.hash(name, children); + } + @Override public String toString() { V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java index df9e58fa319fd..5e8aeafe74515 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java @@ -40,4 +40,7 @@ public interface Literal extends Expression { * Returns the SQL data type of the literal. */ DataType dataType(); + + @Override + default Expression[] children() { return EMPTY_EXPRESSION; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java index 167432fa0e86a..8c0f029a35832 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java @@ -32,4 +32,10 @@ public interface NamedReference extends Expression { * Each string in the returned array represents a field name. */ String[] fieldNames(); + + @Override + default Expression[] children() { return EMPTY_EXPRESSION; } + + @Override + default NamedReference[] references() { return new NamedReference[]{ this }; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java index 72252457df26e..51401786ca5d7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java @@ -40,4 +40,7 @@ public interface SortOrder extends Expression { * Returns the null ordering. */ NullOrdering nullOrdering(); + + @Override + default Expression[] children() { return new Expression[]{ expression() }; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java index 297205825c6a4..e9ead7fc5fd2a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java @@ -34,13 +34,11 @@ public interface Transform extends Expression { */ String name(); - /** - * Returns all field references in the transform arguments. - */ - NamedReference[] references(); - /** * Returns the arguments passed to the transform function. 
*/ Expression[] arguments(); + + @Override + default Expression[] children() { return arguments(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java index cc9d27ab8e59c..d09e5f7ba28a3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java @@ -38,6 +38,9 @@ public Avg(Expression column, boolean isDistinct) { public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } + @Override + public Expression[] children() { return new Expression[]{ input }; } + @Override public String toString() { if (isDistinct) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java index 54c64b83c5d52..c840b29ad2546 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java @@ -38,6 +38,9 @@ public Count(Expression column, boolean isDistinct) { public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } + @Override + public Expression[] children() { return new Expression[]{ input }; } + @Override public String toString() { if (isDistinct) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java index 13801194b63cb..ff8639cbd05a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the number of rows in a group. 
@@ -30,6 +31,9 @@ public final class CountStar implements AggregateFunc { public CountStar() { } + @Override + public Expression[] children() { return EMPTY_EXPRESSION; } + @Override public String toString() { return "COUNT(*)"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java index 0ff26c8875b7a..7016644543447 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -22,7 +22,6 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; -import org.apache.spark.sql.connector.expressions.NamedReference; /** * The general implementation of {@link AggregateFunc}, which contains the upper-cased function @@ -46,21 +45,23 @@ public final class GeneralAggregateFunc implements AggregateFunc { private final String name; private final boolean isDistinct; - private final NamedReference[] inputs; + private final Expression[] children; public String name() { return name; } public boolean isDistinct() { return isDistinct; } - public NamedReference[] inputs() { return inputs; } - public GeneralAggregateFunc(String name, boolean isDistinct, NamedReference[] inputs) { + public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) { this.name = name; this.isDistinct = isDistinct; - this.inputs = inputs; + this.children = children; } + @Override + public Expression[] children() { return children; } + @Override public String toString() { - String inputsString = Arrays.stream(inputs) + String inputsString = Arrays.stream(children) .map(Expression::describe) .collect(Collectors.joining(", ")); if (isDistinct) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java index 971aac279e09b..089d2bd751763 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java @@ -33,6 +33,9 @@ public final class Max implements AggregateFunc { public Expression column() { return input; } + @Override + public Expression[] children() { return new Expression[]{ input }; } + @Override public String toString() { return "MAX(" + input.describe() + ")"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java index 8d0644b0f0103..253cdea41dd76 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java @@ -33,6 +33,9 @@ public final class Min implements AggregateFunc { public Expression column() { return input; } + @Override + public Expression[] children() { return new Expression[]{ input }; } + @Override public String toString() { return "MIN(" + input.describe() + ")"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java index 
721ef31c9a817..4e01b92d8c369 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java @@ -38,6 +38,9 @@ public Sum(Expression column, boolean isDistinct) { public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } + @Override + public Expression[] children() { return new Expression[]{ input }; } + @Override public String toString() { if (isDistinct) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java index 72ed83f86df6d..accdd1acd7d0e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java @@ -17,34 +17,30 @@ package org.apache.spark.sql.connector.expressions.filter; -import java.util.Objects; - import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; /** - * A filter that always evaluates to {@code false}. + * A predicate that always evaluates to {@code false}. * * @since 3.3.0 */ @Evolving -public final class AlwaysFalse extends Filter { +public final class AlwaysFalse extends Predicate implements Literal { - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - return true; + public AlwaysFalse() { + super("ALWAYS_FALSE", new Predicate[]{}); } - @Override - public int hashCode() { - return Objects.hash(); + public Boolean value() { + return false; } - @Override - public String toString() { return "FALSE"; } + public DataType dataType() { + return DataTypes.BooleanType; + } - @Override - public NamedReference[] references() { return EMPTY_REFERENCE; } + public String toString() { return "FALSE"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java index b6d39c3f64a77..5a14f64b9b7e2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java @@ -17,34 +17,30 @@ package org.apache.spark.sql.connector.expressions.filter; -import java.util.Objects; - import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; /** - * A filter that always evaluates to {@code true}. + * A predicate that always evaluates to {@code true}. 
* * @since 3.3.0 */ @Evolving -public final class AlwaysTrue extends Filter { +public final class AlwaysTrue extends Predicate implements Literal { + + public AlwaysTrue() { + super("ALWAYS_TRUE", new Predicate[]{}); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + public Boolean value() { return true; } - @Override - public int hashCode() { - return Objects.hash(); + public DataType dataType() { + return DataTypes.BooleanType; } - @Override public String toString() { return "TRUE"; } - - @Override - public NamedReference[] references() { return EMPTY_REFERENCE; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java index e0b8b13acb158..179a4b3c6349d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java @@ -20,20 +20,18 @@ import org.apache.spark.annotation.Evolving; /** - * A filter that evaluates to {@code true} iff both {@code left} and {@code right} evaluate to + * A predicate that evaluates to {@code true} iff both {@code left} and {@code right} evaluate to * {@code true}. * * @since 3.3.0 */ @Evolving -public final class And extends BinaryFilter { +public final class And extends Predicate { - public And(Filter left, Filter right) { - super(left, right); + public And(Predicate left, Predicate right) { + super("AND", new Predicate[]{left, right}); } - @Override - public String toString() { - return String.format("(%s) AND (%s)", left.describe(), right.describe()); - } + public Predicate left() { return (Predicate) children()[0]; } + public Predicate right() { return (Predicate) children()[1]; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java deleted file mode 100644 index 0ae6e5af3ca1a..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.connector.expressions.filter; - -import java.util.Objects; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Literal; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * Base class for {@link EqualNullSafe}, {@link EqualTo}, {@link GreaterThan}, - * {@link GreaterThanOrEqual}, {@link LessThan}, {@link LessThanOrEqual} - * - * @since 3.3.0 - */ -@Evolving -abstract class BinaryComparison extends Filter { - protected final NamedReference column; - protected final Literal value; - - protected BinaryComparison(NamedReference column, Literal value) { - this.column = column; - this.value = value; - } - - public NamedReference column() { return column; } - public Literal value() { return value; } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - BinaryComparison that = (BinaryComparison) o; - return Objects.equals(column, that.column) && Objects.equals(value, that.value); - } - - @Override - public int hashCode() { - return Objects.hash(column, value); - } - - @Override - public NamedReference[] references() { return new NamedReference[] { column }; } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryFilter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryFilter.java deleted file mode 100644 index ac4b9f281e9ca..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryFilter.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.connector.expressions.filter; - -import java.util.Objects; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * Base class for {@link And}, {@link Or} - * - * @since 3.3.0 - */ -@Evolving -abstract class BinaryFilter extends Filter { - protected final Filter left; - protected final Filter right; - - protected BinaryFilter(Filter left, Filter right) { - this.left = left; - this.right = right; - } - - public Filter left() { return left; } - public Filter right() { return right; } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - BinaryFilter and = (BinaryFilter) o; - return Objects.equals(left, and.left) && Objects.equals(right, and.right); - } - - @Override - public int hashCode() { - return Objects.hash(left, right); - } - - @Override - public NamedReference[] references() { - NamedReference[] refLeft = left.references(); - NamedReference[] refRight = right.references(); - NamedReference[] arr = new NamedReference[refLeft.length + refRight.length]; - System.arraycopy(refLeft, 0, arr, 0, refLeft.length); - System.arraycopy(refRight, 0, arr, refLeft.length, refRight.length); - return arr; - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java deleted file mode 100644 index 34b529194e075..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Literal; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * Performs equality comparison, similar to {@link EqualTo}. However, this differs from - * {@link EqualTo} in that it returns {@code true} (rather than NULL) if both inputs are NULL, - * and {@code false} (rather than NULL) if one of the input is NULL and the other is not NULL. 
- * - * @since 3.3.0 - */ -@Evolving -public final class EqualNullSafe extends BinaryComparison { - - public EqualNullSafe(NamedReference column, Literal value) { - super(column, value); - } - - @Override - public String toString() { return this.column.describe() + " <=> " + value.describe(); } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java deleted file mode 100644 index b9c4fe053b83c..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Literal; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value - * equal to {@code value}. - * - * @since 3.3.0 - */ -@Evolving -public final class EqualTo extends BinaryComparison { - - public EqualTo(NamedReference column, Literal value) { - super(column, value); - } - - @Override - public String toString() { return column.describe() + " = " + value.describe(); } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java deleted file mode 100644 index af87e76d2ff7d..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.connector.expressions.filter; - -import java.io.Serializable; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Expression; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * Filter base class - * - * @since 3.3.0 - */ -@Evolving -public abstract class Filter implements Expression, Serializable { - - protected static final NamedReference[] EMPTY_REFERENCE = new NamedReference[0]; - - /** - * Returns list of columns that are referenced by this filter. - */ - public abstract NamedReference[] references(); -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java deleted file mode 100644 index a3374f359ea29..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Literal; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value - * greater than {@code value}. - * - * @since 3.3.0 - */ -@Evolving -public final class GreaterThan extends BinaryComparison { - - public GreaterThan(NamedReference column, Literal value) { - super(column, value); - } - - @Override - public String toString() { return column.describe() + " > " + value.describe(); } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java deleted file mode 100644 index 4ee921014da41..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Literal; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value - * greater than or equal to {@code value}. - * - * @since 3.3.0 - */ -@Evolving -public final class GreaterThanOrEqual extends BinaryComparison { - - public GreaterThanOrEqual(NamedReference column, Literal value) { - super(column, value); - } - - @Override - public String toString() { return column.describe() + " >= " + value.describe(); } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java deleted file mode 100644 index 8d6490b8984fd..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import java.util.Arrays; -import java.util.Objects; -import java.util.stream.Collectors; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Literal; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to one of the - * {@code values} in the array. 
- * - * @since 3.3.0 - */ -@Evolving -public final class In extends Filter { - static final int MAX_LEN_TO_PRINT = 50; - private final NamedReference column; - private final Literal[] values; - - public In(NamedReference column, Literal[] values) { - this.column = column; - this.values = values; - } - - public NamedReference column() { return column; } - public Literal[] values() { return values; } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - In in = (In) o; - return Objects.equals(column, in.column) && values.length == in.values.length - && Arrays.asList(values).containsAll(Arrays.asList(in.values)); - } - - @Override - public int hashCode() { - int result = Objects.hash(column); - result = 31 * result + Arrays.hashCode(values); - return result; - } - - @Override - public String toString() { - String res = Arrays.stream(values).limit((MAX_LEN_TO_PRINT)).map(Literal::describe) - .collect(Collectors.joining(", ")); - if (values.length > MAX_LEN_TO_PRINT) { - res += "..."; - } - return column.describe() + " IN (" + res + ")"; - } - - @Override - public NamedReference[] references() { return new NamedReference[] { column }; } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java deleted file mode 100644 index 2cf000e99878e..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import java.util.Objects; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to a non-null value. 
- * - * @since 3.3.0 - */ -@Evolving -public final class IsNotNull extends Filter { - private final NamedReference column; - - public IsNotNull(NamedReference column) { - this.column = column; - } - - public NamedReference column() { return column; } - - @Override - public String toString() { return column.describe() + " IS NOT NULL"; } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - IsNotNull isNotNull = (IsNotNull) o; - return Objects.equals(column, isNotNull.column); - } - - @Override - public int hashCode() { - return Objects.hash(column); - } - - @Override - public NamedReference[] references() { return new NamedReference[] { column }; } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java deleted file mode 100644 index 1cd497c02242e..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import java.util.Objects; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to null. - * - * @since 3.3.0 - */ -@Evolving -public final class IsNull extends Filter { - private final NamedReference column; - - public IsNull(NamedReference column) { - this.column = column; - } - - public NamedReference column() { return column; } - - @Override - public String toString() { return column.describe() + " IS NULL"; } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - IsNull isNull = (IsNull) o; - return Objects.equals(column, isNull.column); - } - - @Override - public int hashCode() { - return Objects.hash(column); - } - - @Override - public NamedReference[] references() { return new NamedReference[] { column }; } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java deleted file mode 100644 index 9fa5cfb87f527..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Literal; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value - * less than {@code value}. - * - * @since 3.3.0 - */ -@Evolving -public final class LessThan extends BinaryComparison { - - public LessThan(NamedReference column, Literal value) { - super(column, value); - } - - @Override - public String toString() { return column.describe() + " < " + value.describe(); } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java deleted file mode 100644 index a41b3c8045d5a..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Literal; -import org.apache.spark.sql.connector.expressions.NamedReference; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value - * less than or equal to {@code value}. 
- * - * @since 3.3.0 - */ -@Evolving -public final class LessThanOrEqual extends BinaryComparison { - - public LessThanOrEqual(NamedReference column, Literal value) { - super(column, value); - } - - @Override - public String toString() { return column.describe() + " <= " + value.describe(); } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java index 69746f59ee933..d65c9f0b6c3d9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java @@ -17,40 +17,19 @@ package org.apache.spark.sql.connector.expressions.filter; -import java.util.Objects; - import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; /** - * A filter that evaluates to {@code true} iff {@code child} is evaluated to {@code false}. + * A predicate that evaluates to {@code true} iff {@code child} is evaluated to {@code false}. * * @since 3.3.0 */ @Evolving -public final class Not extends Filter { - private final Filter child; - - public Not(Filter child) { this.child = child; } - - public Filter child() { return child; } +public final class Not extends Predicate { - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Not not = (Not) o; - return Objects.equals(child, not.child); + public Not(Predicate child) { + super("NOT", new Predicate[]{child}); } - @Override - public int hashCode() { - return Objects.hash(child); - } - - @Override - public String toString() { return "NOT (" + child.describe() + ")"; } - - @Override - public NamedReference[] references() { return child.references(); } + public Predicate child() { return (Predicate) children()[0]; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java index baa33d849feef..7f1717cc7da58 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java @@ -20,20 +20,18 @@ import org.apache.spark.annotation.Evolving; /** - * A filter that evaluates to {@code true} iff at least one of {@code left} or {@code right} + * A predicate that evaluates to {@code true} iff at least one of {@code left} or {@code right} * evaluates to {@code true}. 
* * @since 3.3.0 */ @Evolving -public final class Or extends BinaryFilter { +public final class Or extends Predicate { - public Or(Filter left, Filter right) { - super(left, right); + public Or(Predicate left, Predicate right) { + super("OR", new Predicate[]{left, right}); } - @Override - public String toString() { - return String.format("(%s) OR (%s)", left.describe(), right.describe()); - } + public Predicate left() { return (Predicate) children()[0]; } + public Predicate right() { return (Predicate) children()[1]; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java new file mode 100644 index 0000000000000..e58cddc274c5f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; + +/** + * The general representation of predicate expressions, which contains the upper-cased expression + * name and all the children expressions. You can also use these concrete subclasses for better + * type safety: {@link And}, {@link Or}, {@link Not}, {@link AlwaysTrue}, {@link AlwaysFalse}. + *

 + * The currently supported predicate expressions:
 + *
 + *  Name: IS_NULL
 + *   - SQL semantic: expr IS NULL
 + *   - Since version: 3.3.0
 + *  Name: IS_NOT_NULL
 + *   - SQL semantic: expr IS NOT NULL
 + *   - Since version: 3.3.0
 + *  Name: STARTS_WITH
 + *   - SQL semantic: expr1 LIKE 'expr2%'
 + *   - Since version: 3.3.0
 + *  Name: ENDS_WITH
 + *   - SQL semantic: expr1 LIKE '%expr2'
 + *   - Since version: 3.3.0
 + *  Name: CONTAINS
 + *   - SQL semantic: expr1 LIKE '%expr2%'
 + *   - Since version: 3.3.0
 + *  Name: IN
 + *   - SQL semantic: expr IN (expr1, expr2, ...)
 + *   - Since version: 3.3.0
 + *  Name: =
 + *   - SQL semantic: expr1 = expr2
 + *   - Since version: 3.3.0
 + *  Name: <>
 + *   - SQL semantic: expr1 <> expr2
 + *   - Since version: 3.3.0
 + *  Name: <=>
 + *   - SQL semantic: null-safe version of expr1 = expr2
 + *   - Since version: 3.3.0
 + *  Name: <
 + *   - SQL semantic: expr1 < expr2
 + *   - Since version: 3.3.0
 + *  Name: <=
 + *   - SQL semantic: expr1 <= expr2
 + *   - Since version: 3.3.0
 + *  Name: >
 + *   - SQL semantic: expr1 > expr2
 + *   - Since version: 3.3.0
 + *  Name: >=
 + *   - SQL semantic: expr1 >= expr2
 + *   - Since version: 3.3.0
 + *  Name: AND
 + *   - SQL semantic: expr1 AND expr2
 + *   - Since version: 3.3.0
 + *  Name: OR
 + *   - SQL semantic: expr1 OR expr2
 + *   - Since version: 3.3.0
 + *  Name: NOT
 + *   - SQL semantic: NOT expr
 + *   - Since version: 3.3.0
 + *  Name: ALWAYS_TRUE
 + *   - SQL semantic: TRUE
 + *   - Since version: 3.3.0
 + *  Name: ALWAYS_FALSE
 + *   - SQL semantic: FALSE
 + *   - Since version: 3.3.0
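For illustration only (not part of this patch): a sketch of composing the predicates listed above with the new Predicate class and its typed subclasses. The column name and value are hypothetical, and Expressions.column/Expressions.literal are assumed to be the existing DSV2 factory methods.

    // Hedged sketch, not part of this patch. "dept" and 1 are hypothetical values.
    import org.apache.spark.sql.connector.expressions.Expression;
    import org.apache.spark.sql.connector.expressions.Expressions;
    import org.apache.spark.sql.connector.expressions.filter.And;
    import org.apache.spark.sql.connector.expressions.filter.Not;
    import org.apache.spark.sql.connector.expressions.filter.Predicate;

    public class PredicateSketch {
      public static void main(String[] args) {
        // dept IS NOT NULL
        Predicate notNull = new Predicate("IS_NOT_NULL",
            new Expression[]{ Expressions.column("dept") });
        // dept = 1
        Predicate equalsOne = new Predicate("=",
            new Expression[]{ Expressions.column("dept"), Expressions.literal(1) });
        // (dept IS NOT NULL) AND (dept = 1), then negated with NOT.
        Predicate combined = new And(notNull, equalsOne);
        Predicate negated = new Not(combined);
        // describe() renders a SQL-like string via V2ExpressionSQLBuilder.
        System.out.println(negated.describe());
      }
    }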
        + * + * @since 3.3.0 + */ +@Evolving +public class Predicate extends GeneralScalarExpression { + + public Predicate(String name, Expression[] children) { + super(name, children); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java deleted file mode 100644 index 9a01e4d574888..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to - * a string that contains {@code value}. - * - * @since 3.3.0 - */ -@Evolving -public final class StringContains extends StringPredicate { - - public StringContains(NamedReference column, UTF8String value) { - super(column, value); - } - - @Override - public String toString() { return "STRING_CONTAINS(" + column.describe() + ", " + value + ")"; } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java deleted file mode 100644 index 11b8317ba4895..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.connector.expressions.filter; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to - * a string that ends with {@code value}. - * - * @since 3.3.0 - */ -@Evolving -public final class StringEndsWith extends StringPredicate { - - public StringEndsWith(NamedReference column, UTF8String value) { - super(column, value); - } - - @Override - public String toString() { return "STRING_ENDS_WITH(" + column.describe() + ", " + value + ")"; } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java deleted file mode 100644 index ffe5d5dba45b3..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.connector.expressions.filter; - -import java.util.Objects; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * Base class for {@link StringContains}, {@link StringStartsWith}, - * {@link StringEndsWith} - * - * @since 3.3.0 - */ -@Evolving -abstract class StringPredicate extends Filter { - protected final NamedReference column; - protected final UTF8String value; - - protected StringPredicate(NamedReference column, UTF8String value) { - this.column = column; - this.value = value; - } - - public NamedReference column() { return column; } - public UTF8String value() { return value; } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - StringPredicate that = (StringPredicate) o; - return Objects.equals(column, that.column) && Objects.equals(value, that.value); - } - - @Override - public int hashCode() { - return Objects.hash(column, value); - } - - @Override - public NamedReference[] references() { return new NamedReference[] { column }; } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java deleted file mode 100644 index 38a5de1921cdc..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions.filter; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * A filter that evaluates to {@code true} iff the {@code column} evaluates to - * a string that starts with {@code value}. 
- * - * @since 3.3.0 - */ -@Evolving -public final class StringStartsWith extends StringPredicate { - - public StringStartsWith(NamedReference column, UTF8String value) { - super(column, value); - } - - @Override - public String toString() { - return "STRING_STARTS_WITH(" + column.describe() + ", " + value + ")"; - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java index 1ba9939dd0849..1fec939aeb474 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java @@ -18,11 +18,14 @@ package org.apache.spark.sql.connector.read; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.filter.Filter; +import org.apache.spark.sql.connector.expressions.filter.Predicate; /** * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to - * push down filters to the data source and reduce the size of the data to be read. + * push down V2 {@link Predicate} to the data source and reduce the size of the data to be read. + * Please Note that this interface is preferred over {@link SupportsPushDownFilters}, which uses + * V1 {@link org.apache.spark.sql.sources.Filter} and is less efficient due to the + * internal -> external data conversion. * * @since 3.3.0 */ @@ -30,28 +33,31 @@ public interface SupportsPushDownV2Filters extends ScanBuilder { /** - * Pushes down filters, and returns filters that need to be evaluated after scanning. + * Pushes down predicates, and returns predicates that need to be evaluated after scanning. *

-   * Rows should be returned from the data source if and only if all of the filters match. That is,
-   * filters must be interpreted as ANDed together.
+   * Rows should be returned from the data source if and only if all of the predicates match.
+   * That is, predicates must be interpreted as ANDed together.
    */
-  Filter[] pushFilters(Filter[] filters);
+  Predicate[] pushPredicates(Predicate[] predicates);

   /**
-   * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}.
+   * Returns the predicates that are pushed to the data source via
+   * {@link #pushPredicates(Predicate[])}.
    * <p>
-   * There are 3 kinds of filters:
+   * There are 3 kinds of predicates:
    * <ul>
-   *  <li>pushable filters which don't need to be evaluated again after scanning.</li>
-   *  <li>pushable filters which still need to be evaluated after scanning, e.g. parquet row
-   *      group filter.</li>
-   *  <li>non-pushable filters.</li>
+   *  <li>pushable predicates which don't need to be evaluated again after scanning.</li>
+   *  <li>pushable predicates which still need to be evaluated after scanning, e.g. parquet row
+   *      group predicate.</li>
+   *  <li>non-pushable predicates.</li>
    * </ul>
    * <p>
-   * Both case 1 and 2 should be considered as pushed filters and should be returned by this method.
+   * Both case 1 and 2 should be considered as pushed predicates and should be returned
+   * by this method.
    * <p>
        - * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} - * is never called, empty array should be returned for this case. + * It's possible that there is no predicates in the query and + * {@link #pushPredicates(Predicate[])} is never called, + * empty array should be returned for this case. */ - Filter[] pushedFilters(); + Predicate[] pushedPredicates(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 0af0d88b0f622..91dae749f974b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -17,39 +17,53 @@ package org.apache.spark.sql.connector.util; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; import org.apache.spark.sql.connector.expressions.Expression; -import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; -import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.connector.expressions.Literal; /** * The builder to generate SQL from V2 expressions. */ public class V2ExpressionSQLBuilder { + public String build(Expression expr) { - if (expr instanceof LiteralValue) { - return visitLiteral((LiteralValue) expr); - } else if (expr instanceof FieldReference) { - return visitFieldReference((FieldReference) expr); + if (expr instanceof Literal) { + return visitLiteral((Literal) expr); + } else if (expr instanceof NamedReference) { + return visitNamedReference((NamedReference) expr); } else if (expr instanceof GeneralScalarExpression) { GeneralScalarExpression e = (GeneralScalarExpression) expr; String name = e.name(); switch (name) { + case "IN": { + List children = + Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList()); + return visitIn(children.get(0), children.subList(1, children.size())); + } case "IS_NULL": return visitIsNull(build(e.children()[0])); case "IS_NOT_NULL": return visitIsNotNull(build(e.children()[0])); + case "STARTS_WITH": + return visitStartsWith(build(e.children()[0]), build(e.children()[1])); + case "ENDS_WITH": + return visitEndsWith(build(e.children()[0]), build(e.children()[1])); + case "CONTAINS": + return visitContains(build(e.children()[0]), build(e.children()[1])); case "=": - case "!=": + case "<>": case "<=>": case "<": case "<=": case ">": case ">=": - return visitBinaryComparison(name, build(e.children()[0]), build(e.children()[1])); + return visitBinaryComparison( + name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); case "+": case "*": case "/": @@ -57,12 +71,14 @@ public String build(Expression expr) { case "&": case "|": case "^": - return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1])); + return visitBinaryArithmetic( + name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); case "-": if (e.children().length == 1) { return visitUnaryArithmetic(name, build(e.children()[0])); } else { - return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1])); + return visitBinaryArithmetic( + name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); } case "AND": return 
visitAnd(name, build(e.children()[0]), build(e.children()[1])); @@ -72,12 +88,11 @@ public String build(Expression expr) { return visitNot(build(e.children()[0])); case "~": return visitUnaryArithmetic(name, build(e.children()[0])); - case "CASE_WHEN": - List children = new ArrayList<>(); - for (Expression child : e.children()) { - children.add(build(child)); - } + case "CASE_WHEN": { + List children = + Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList()); return visitCaseWhen(children.toArray(new String[e.children().length])); + } // TODO supports other expressions default: return visitUnexpectedExpr(expr); @@ -87,12 +102,19 @@ public String build(Expression expr) { } } - protected String visitLiteral(LiteralValue literalValue) { - return literalValue.toString(); + protected String visitLiteral(Literal literal) { + return literal.toString(); } - protected String visitFieldReference(FieldReference fieldRef) { - return fieldRef.toString(); + protected String visitNamedReference(NamedReference namedRef) { + return namedRef.toString(); + } + + protected String visitIn(String v, List list) { + if (list.isEmpty()) { + return "CASE WHEN " + v + " IS NULL THEN NULL ELSE FALSE END"; + } + return v + " IN (" + list.stream().collect(Collectors.joining(", ")) + ")"; } protected String visitIsNull(String v) { @@ -103,12 +125,46 @@ protected String visitIsNotNull(String v) { return v + " IS NOT NULL"; } + protected String visitStartsWith(String l, String r) { + // Remove quotes at the beginning and end. + // e.g. converts "'str'" to "str". + String value = r.substring(1, r.length() - 1); + return l + " LIKE '" + value + "%'"; + } + + protected String visitEndsWith(String l, String r) { + // Remove quotes at the beginning and end. + // e.g. converts "'str'" to "str". + String value = r.substring(1, r.length() - 1); + return l + " LIKE '%" + value + "'"; + } + + protected String visitContains(String l, String r) { + // Remove quotes at the beginning and end. + // e.g. converts "'str'" to "str". 
+ String value = r.substring(1, r.length() - 1); + return l + " LIKE '%" + value + "%'"; + } + + private String inputToSQL(Expression input) { + if (input.children().length > 1) { + return "(" + build(input) + ")"; + } else { + return build(input); + } + } + protected String visitBinaryComparison(String name, String l, String r) { - return "(" + l + ") " + name + " (" + r + ")"; + switch (name) { + case "<=>": + return "(" + l + " = " + r + ") OR (" + l + " IS NULL AND " + r + " IS NULL)"; + default: + return l + " " + name + " " + r; + } } protected String visitBinaryArithmetic(String name, String l, String r) { - return "(" + l + ") " + name + " (" + r + ")"; + return l + " " + name + " " + r; } protected String visitAnd(String name, String l, String r) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 80658f7cec2e3..c4d9be95e97ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -18,7 +18,12 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.{Evolving, Stable} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath +import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, Predicate} +import org.apache.spark.sql.types.StringType +import org.apache.spark.unsafe.types.UTF8String //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. 
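The hunks below add a toV2 conversion to every V1 Filter in filters.scala. A minimal sketch of what that conversion is expected to produce, using hypothetical column names and values (toV2 is private[sql], so this only compiles from inside the org.apache.spark.sql package):

import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.sources.{And, EqualTo, IsNotNull}
import org.apache.spark.sql.types.IntegerType

// Building the V2 predicate by hand, the same shape toV2 produces for EqualTo:
val byHand = new Predicate("=", Array(FieldReference("i"), LiteralValue(1, IntegerType)))

// The same predicate obtained from a V1 filter tree; And/Or/Not recurse into their children,
// yielding a Predicate named "AND" whose children are the "=" and "IS_NOT_NULL" predicates.
val fromV1 = And(EqualTo("i", 1), IsNotNull("i")).toV2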
@@ -64,6 +69,11 @@ sealed abstract class Filter { private[sql] def containsNestedColumn: Boolean = { this.v2references.exists(_.length > 1) } + + /** + * Converts V1 filter to V2 filter + */ + private[sql] def toV2: Predicate } /** @@ -78,6 +88,11 @@ sealed abstract class Filter { @Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate("=", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -93,6 +108,11 @@ case class EqualTo(attribute: String, value: Any) extends Filter { @Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate("<=>", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -107,6 +127,11 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { @Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate(">", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -121,6 +146,11 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { @Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate(">=", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -135,6 +165,11 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { @Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate("<", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -149,6 +184,11 @@ case class LessThan(attribute: String, value: Any) extends Filter { @Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate("<=", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -185,6 +225,13 @@ case class In(attribute: String, values: Array[Any]) extends Filter { } override def references: Array[String] = Array(attribute) ++ values.flatMap(findReferences) + override def toV2: Predicate = { + val literals = values.map { value => + val literal = Literal(value) + LiteralValue(literal.value, literal.dataType) + } + new Predicate("IN", FieldReference(attribute) +: literals) + } } /** @@ -198,6 +245,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter { @Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new Predicate("IS_NULL", Array(FieldReference(attribute))) } /** @@ -211,6 +259,7 @@ case 
class IsNull(attribute: String) extends Filter { @Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new Predicate("IS_NOT_NULL", Array(FieldReference(attribute))) } /** @@ -221,6 +270,7 @@ case class IsNotNull(attribute: String) extends Filter { @Stable case class And(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references + override def toV2: Predicate = new Predicate("AND", Seq(left, right).map(_.toV2).toArray) } /** @@ -231,6 +281,7 @@ case class And(left: Filter, right: Filter) extends Filter { @Stable case class Or(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references + override def toV2: Predicate = new Predicate("OR", Seq(left, right).map(_.toV2).toArray) } /** @@ -241,6 +292,7 @@ case class Or(left: Filter, right: Filter) extends Filter { @Stable case class Not(child: Filter) extends Filter { override def references: Array[String] = child.references + override def toV2: Predicate = new Predicate("NOT", Array(child.toV2)) } /** @@ -255,6 +307,8 @@ case class Not(child: Filter) extends Filter { @Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new Predicate("STARTS_WITH", + Array(FieldReference(attribute), LiteralValue(UTF8String.fromString(value), StringType))) } /** @@ -269,6 +323,8 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { @Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new Predicate("ENDS_WITH", + Array(FieldReference(attribute), LiteralValue(UTF8String.fromString(value), StringType))) } /** @@ -283,6 +339,8 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { @Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new Predicate("CONTAINS", + Array(FieldReference(attribute), LiteralValue(UTF8String.fromString(value), StringType))) } /** @@ -293,6 +351,7 @@ case class StringContains(attribute: String, value: String) extends Filter { @Evolving case class AlwaysTrue() extends Filter { override def references: Array[String] = Array.empty + override def toV2: Predicate = new V2AlwaysTrue() } @Evolving @@ -307,6 +366,7 @@ object AlwaysTrue extends AlwaysTrue { @Evolving case class AlwaysFalse() extends Filter { override def references: Array[String] = Array.empty + override def toV2: Predicate = new V2AlwaysFalse() } @Evolving diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index b2371ce667ffc..4a50e063bee68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -44,7 +44,6 @@ class TransformExtractorSuite extends SparkFunSuite { */ private def transform(func: String, ref: NamedReference): Transform = new Transform { override def name: String = func - override def references: 
Array[NamedReference] = Array(ref) override def arguments: Array[Expression] = Array(ref) override def toString: String = ref.describe } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 2ffae68284cc0..a04e6470f6bf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,15 +17,21 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal, Multiply, Not, Or, Remainder, Subtract, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} +import org.apache.spark.sql.execution.datasources.PushableColumn +import org.apache.spark.sql.types.BooleanType /** * The builder to generate V2 expressions from catalyst expressions. */ -class V2ExpressionBuilder(e: Expression) { +class V2ExpressionBuilder( + e: Expression, nestedPredicatePushdownEnabled: Boolean = false, isPredicate: Boolean = false) { - def build(): Option[V2Expression] = generateExpression(e) + val pushableColumn = PushableColumn(nestedPredicatePushdownEnabled) + + def build(): Option[V2Expression] = generateExpression(e, isPredicate) private def canTranslate(b: BinaryOperator) = b match { case _: And | _: Or => true @@ -39,18 +45,83 @@ class V2ExpressionBuilder(e: Expression) { case _ => false } - private def generateExpression(expr: Expression): Option[V2Expression] = expr match { + private def generateExpression( + expr: Expression, isPredicate: Boolean = false): Option[V2Expression] = expr match { + case Literal(true, BooleanType) => Some(new AlwaysTrue()) + case Literal(false, BooleanType) => Some(new AlwaysFalse()) case Literal(value, dataType) => Some(LiteralValue(value, dataType)) - case attr: Attribute => Some(FieldReference(attr.name)) + case col @ pushableColumn(name) if nestedPredicatePushdownEnabled => + if (isPredicate && col.dataType.isInstanceOf[BooleanType]) { + Some(new V2Predicate("=", Array(FieldReference(name), LiteralValue(true, BooleanType)))) + } else { + Some(FieldReference(name)) + } + case pushableColumn(name) if !nestedPredicatePushdownEnabled => + Some(FieldReference(name)) + case in @ InSet(child, hset) => + generateExpression(child).map { v => + val children = + (v +: hset.toSeq.map(elem => LiteralValue(elem, in.dataType))).toArray[V2Expression] + new V2Predicate("IN", children) + } + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. 
+ case In(value, list) => + val v = generateExpression(value) + val listExpressions = list.flatMap(generateExpression(_)) + if (v.isDefined && list.length == listExpressions.length) { + val children = (v.get +: listExpressions).toArray[V2Expression] + // The children looks like [expr, value1, ..., valueN] + Some(new V2Predicate("IN", children)) + } else { + None + } case IsNull(col) => generateExpression(col) - .map(c => new GeneralScalarExpression("IS_NULL", Array[V2Expression](c))) + .map(c => new V2Predicate("IS_NULL", Array[V2Expression](c))) case IsNotNull(col) => generateExpression(col) - .map(c => new GeneralScalarExpression("IS_NOT_NULL", Array[V2Expression](c))) - case b: BinaryOperator if canTranslate(b) => - val left = generateExpression(b.left) - val right = generateExpression(b.right) + .map(c => new V2Predicate("IS_NOT_NULL", Array[V2Expression](c))) + case p: StringPredicate => + val left = generateExpression(p.left) + val right = generateExpression(p.right) if (left.isDefined && right.isDefined) { - Some(new GeneralScalarExpression(b.sqlOperator, Array[V2Expression](left.get, right.get))) + val name = p match { + case _: StartsWith => "STARTS_WITH" + case _: EndsWith => "ENDS_WITH" + case _: Contains => "CONTAINS" + } + Some(new V2Predicate(name, Array[V2Expression](left.get, right.get))) + } else { + None + } + case and: And => + val l = generateExpression(and.left, true) + val r = generateExpression(and.right, true) + if (l.isDefined && r.isDefined) { + assert(l.get.isInstanceOf[V2Predicate] && r.get.isInstanceOf[V2Predicate]) + Some(new V2And(l.get.asInstanceOf[V2Predicate], r.get.asInstanceOf[V2Predicate])) + } else { + None + } + case or: Or => + val l = generateExpression(or.left, true) + val r = generateExpression(or.right, true) + if (l.isDefined && r.isDefined) { + assert(l.get.isInstanceOf[V2Predicate] && r.get.isInstanceOf[V2Predicate]) + Some(new V2Or(l.get.asInstanceOf[V2Predicate], r.get.asInstanceOf[V2Predicate])) + } else { + None + } + case b: BinaryOperator if canTranslate(b) => + val l = generateExpression(b.left) + val r = generateExpression(b.right) + if (l.isDefined && r.isDefined) { + b match { + case _: Predicate => + Some(new V2Predicate(b.sqlOperator, Array[V2Expression](l.get, r.get))) + case _ => + Some(new GeneralScalarExpression(b.sqlOperator, Array[V2Expression](l.get, r.get))) + } } else { None } @@ -58,32 +129,35 @@ class V2ExpressionBuilder(e: Expression) { val left = generateExpression(eq.left) val right = generateExpression(eq.right) if (left.isDefined && right.isDefined) { - Some(new GeneralScalarExpression("!=", Array[V2Expression](left.get, right.get))) + Some(new V2Predicate("<>", Array[V2Expression](left.get, right.get))) } else { None } - case Not(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("NOT", Array[V2Expression](v))) + case Not(child) => generateExpression(child, true) // NOT expects predicate + .map { v => + assert(v.isInstanceOf[V2Predicate]) + new V2Not(v.asInstanceOf[V2Predicate]) + } case UnaryMinus(child, true) => generateExpression(child) .map(v => new GeneralScalarExpression("-", Array[V2Expression](v))) case BitwiseNot(child) => generateExpression(child) .map(v => new GeneralScalarExpression("~", Array[V2Expression](v))) case CaseWhen(branches, elseValue) => - val conditions = branches.map(_._1).flatMap(generateExpression) - val values = branches.map(_._2).flatMap(generateExpression) + val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) + val values = 
branches.map(_._2).flatMap(generateExpression(_, true)) if (conditions.length == branches.length && values.length == branches.length) { val branchExpressions = conditions.zip(values).flatMap { case (c, v) => Seq[V2Expression](c, v) } if (elseValue.isDefined) { - elseValue.flatMap(generateExpression).map { v => + elseValue.flatMap(generateExpression(_)).map { v => val children = (branchExpressions :+ v).toArray[V2Expression] // The children looks like [condition1, value1, ..., conditionN, valueN, elseValue] - new GeneralScalarExpression("CASE_WHEN", children) + new V2Predicate("CASE_WHEN", children) } } else { // The children looks like [condition1, value1, ..., conditionN, valueN] - Some(new GeneralScalarExpression("CASE_WHEN", branchExpressions.toArray[V2Expression])) + Some(new V2Predicate("CASE_WHEN", branchExpressions.toArray[V2Expression])) } } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index ae444bf3aabf4..bb5b8e32aef63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -152,9 +152,14 @@ case class RowDataSourceScanExec( pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") } - Map( - "ReadSchema" -> requiredSchema.catalogString, - "PushedFilters" -> seqToString(markedFilters.toSeq)) ++ + val pushedFilters = if (pushedDownOperators.pushedPredicates.nonEmpty) { + seqToString(pushedDownOperators.pushedPredicates.map(_.describe())) + } else { + seqToString(markedFilters.toSeq) + } + + Map("ReadSchema" -> requiredSchema.catalogString, + "PushedFilters" -> pushedFilters) ++ pushedDownOperators.aggregation.fold(Map[String, String]()) { v => Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())), "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 5d0aecb94264d..408da524cbb04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -337,7 +337,7 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, - PushedDownOperators(None, None, None, Seq.empty), + PushedDownOperators(None, None, None, Seq.empty, Seq.empty), toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -411,7 +411,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None, Seq.empty), + PushedDownOperators(None, None, None, Seq.empty, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -434,7 +434,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None, Seq.empty), + PushedDownOperators(None, None, None, Seq.empty, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -858,8 +858,5 @@ object PushableColumnWithoutNestedColumn extends PushableColumnBase { * Get the 
expression of DS V2 to represent catalyst expression that can be pushed down. */ object PushableExpression { - def unapply(e: Expression): Option[V2Expression] = e match { - case PushableColumnWithoutNestedColumn(name) => Some(FieldReference(name)) - case _ => new V2ExpressionBuilder(e).build() - } + def unapply(e: Expression): Option[V2Expression] = new V2ExpressionBuilder(e).build() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index b5224eaf7262b..b30b460ac67db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -26,9 +26,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} -import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.CompletionIterator @@ -92,61 +92,13 @@ object JDBCRDD extends Logging { new StructType(columns.map(name => fieldMap(name))) } - /** - * Turns a single Filter into a String representing a SQL expression. - * Returns None for an unhandled filter. - */ - def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = { - def quote(colName: String): String = dialect.quoteIdentifier(colName) - - Option(f match { - case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}" - case EqualNullSafe(attr, value) => - val col = quote(attr) - s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " + - s"${dialect.compileValue(value)} IS NULL) OR " + - s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))" - case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}" - case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}" - case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}" - case IsNull(attr) => s"${quote(attr)} IS NULL" - case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL" - case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'" - case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'" - case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'" - case In(attr, value) if value.isEmpty => - s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END" - case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})" - case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null) - case Or(f1, f2) => - // We can't compile Or filter unless both sub-filters are compiled successfully. - // It applies too for the following And filter. - // If we can make sure compileFilter supports all filters, we can remove this check. 
- val or = Seq(f1, f2).flatMap(compileFilter(_, dialect)) - if (or.size == 2) { - or.map(p => s"($p)").mkString(" OR ") - } else { - null - } - case And(f1, f2) => - val and = Seq(f1, f2).flatMap(compileFilter(_, dialect)) - if (and.size == 2) { - and.map(p => s"($p)").mkString(" AND ") - } else { - null - } - case _ => null - }) - } - /** * Build and return JDBCRDD from the given information. * * @param sc - Your SparkContext. * @param schema - The Catalyst schema of the underlying database table. * @param requiredColumns - The names of the columns or aggregate columns to SELECT. - * @param filters - The filters to include in all WHERE clauses. + * @param predicates - The predicates to include in all WHERE clauses. * @param parts - An array of JDBCPartitions specifying partition ids and * per-partition WHERE clauses. * @param options - JDBC options that contains url, table and other information. @@ -155,7 +107,7 @@ object JDBCRDD extends Logging { * @param sample - The pushed down tableSample. * @param limit - The pushed down limit. If the value is 0, it means no limit or limit * is not pushed down. - * @param sortValues - The sort values cooperates with limit to realize top N. + * @param sortOrders - The sort orders cooperates with limit to realize top N. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". */ @@ -164,7 +116,7 @@ object JDBCRDD extends Logging { sc: SparkContext, schema: StructType, requiredColumns: Array[String], - filters: Array[Filter], + predicates: Array[Predicate], parts: Array[Partition], options: JDBCOptions, outputSchema: Option[StructType] = None, @@ -185,7 +137,7 @@ object JDBCRDD extends Logging { dialect.createConnectionFactory(options), outputSchema.getOrElse(pruneSchema(schema, requiredColumns)), quotedColumns, - filters, + predicates, parts, url, options, @@ -207,7 +159,7 @@ private[jdbc] class JDBCRDD( getConnection: Int => Connection, schema: StructType, columns: Array[String], - filters: Array[Filter], + predicates: Array[Predicate], partitions: Array[Partition], url: String, options: JDBCOptions, @@ -230,10 +182,10 @@ private[jdbc] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = - filters - .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url))) - .map(p => s"($p)").mkString(" AND ") + private val filterWhereClause: String = { + val dialect = JdbcDialects.get(url) + predicates.flatMap(dialect.compileExpression(_)).map(p => s"($p)").mkString(" AND ") + } /** * A WHERE clause representing both `filters`, if any, and the current partition. 
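With compileFilter removed, the WHERE fragment is now assembled from V2 predicates through the dialect. A sketch of the new path, assuming a hypothetical H2 URL and column names; anything the dialect cannot compile simply drops out of the clause and is left for Spark to evaluate:

import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.IntegerType

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")   // hypothetical URL
val predicates: Array[Predicate] = Array(
  new Predicate(">", Array(FieldReference("salary"), LiteralValue(10000, IntegerType))),
  new Predicate("IS_NOT_NULL", Array(FieldReference("name"))))

// Mirrors the new filterWhereClause: compile each predicate via the dialect,
// keep the ones it understands, and AND the pieces together.
val whereClause = predicates
  .flatMap(dialect.compileExpression(_))
  .map(p => s"($p)")
  .mkString(" AND ")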
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index ecb207363cd59..0f1a1b6dc667b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf @@ -270,10 +271,11 @@ private[sql] case class JDBCRelation( override val needConversion: Boolean = false - // Check if JDBCRDD.compileFilter can accept input filters + // Check if JdbcDialect can compile input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { - filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) + val dialect = JdbcDialects.get(jdbcOptions.url) + filters.filter(f => dialect.compileExpression(f.toV2).isEmpty) } else { filters } @@ -281,17 +283,17 @@ private[sql] case class JDBCRelation( override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { // When pushDownPredicate is false, all Filters that need to be pushed down should be ignored - val pushedFilters = if (jdbcOptions.pushDownPredicate) { - filters + val pushedPredicates = if (jdbcOptions.pushDownPredicate) { + filters.map(_.toV2) } else { - Array.empty[Filter] + Array.empty[Predicate] } // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, schema, requiredColumns, - pushedFilters, + pushedPredicates, parts, jdbcOptions).asInstanceOf[RDD[Row]] } @@ -299,7 +301,7 @@ private[sql] case class JDBCRelation( def buildScan( requiredColumns: Array[String], finalSchema: StructType, - filters: Array[Filter], + predicates: Array[Predicate], groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], limit: Int, @@ -309,7 +311,7 @@ private[sql] case class JDBCRelation( sparkSession.sparkContext, schema, requiredColumns, - filters, + predicates, parts, jdbcOptions, Some(finalSchema), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d07c29e080265..f267a03cbe218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,25 +23,22 @@ import scala.collection.mutable import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, EmptyRow, Expression, Literal, NamedExpression, PredicateHelper, 
SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.catalyst.util.{toPrettySQL, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog} -import org.apache.spark.sql.connector.expressions.{FieldReference, Literal => V2Literal, LiteralValue} -import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, EqualNullSafe => V2EqualNullSafe, EqualTo => V2EqualTo, Filter => V2Filter, GreaterThan => V2GreaterThan, GreaterThanOrEqual => V2GreaterThanOrEqual, In => V2In, IsNotNull => V2IsNotNull, IsNull => V2IsNull, LessThan => V2LessThan, LessThanOrEqual => V2LessThanOrEqual, Not => V2Not, Or => V2Or, StringContains => V2StringContains, StringEndsWith => V2StringEndsWith, StringStartsWith => V2StringStartsWith} +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} import org.apache.spark.sql.connector.read.LocalScan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn, PushableColumnBase} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.{BaseRelation, TableScan} -import org.apache.spark.sql.types.{BooleanType, StringType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.StorageLevel -import org.apache.spark.unsafe.types.UTF8String class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper { @@ -439,71 +436,12 @@ private[sql] object DataSourceV2Strategy { private def translateLeafNodeFilterV2( predicate: Expression, - pushableColumn: PushableColumnBase): Option[V2Filter] = predicate match { - case expressions.EqualTo(pushableColumn(name), Literal(v, t)) => - Some(new V2EqualTo(FieldReference(name), LiteralValue(v, t))) - case expressions.EqualTo(Literal(v, t), pushableColumn(name)) => - Some(new V2EqualTo(FieldReference(name), LiteralValue(v, t))) - - case expressions.EqualNullSafe(pushableColumn(name), Literal(v, t)) => - Some(new V2EqualNullSafe(FieldReference(name), LiteralValue(v, t))) - case expressions.EqualNullSafe(Literal(v, t), pushableColumn(name)) => - Some(new V2EqualNullSafe(FieldReference(name), LiteralValue(v, t))) - - case expressions.GreaterThan(pushableColumn(name), Literal(v, t)) => - Some(new V2GreaterThan(FieldReference(name), LiteralValue(v, t))) - case expressions.GreaterThan(Literal(v, t), pushableColumn(name)) => - Some(new V2LessThan(FieldReference(name), LiteralValue(v, t))) - - case expressions.LessThan(pushableColumn(name), Literal(v, t)) => - 
Some(new V2LessThan(FieldReference(name), LiteralValue(v, t))) - case expressions.LessThan(Literal(v, t), pushableColumn(name)) => - Some(new V2GreaterThan(FieldReference(name), LiteralValue(v, t))) - - case expressions.GreaterThanOrEqual(pushableColumn(name), Literal(v, t)) => - Some(new V2GreaterThanOrEqual(FieldReference(name), LiteralValue(v, t))) - case expressions.GreaterThanOrEqual(Literal(v, t), pushableColumn(name)) => - Some(new V2LessThanOrEqual(FieldReference(name), LiteralValue(v, t))) - - case expressions.LessThanOrEqual(pushableColumn(name), Literal(v, t)) => - Some(new V2LessThanOrEqual(FieldReference(name), LiteralValue(v, t))) - case expressions.LessThanOrEqual(Literal(v, t), pushableColumn(name)) => - Some(new V2GreaterThanOrEqual(FieldReference(name), LiteralValue(v, t))) - - case in @ expressions.InSet(pushableColumn(name), set) => - val values: Array[V2Literal[_]] = - set.toSeq.map(elem => LiteralValue(elem, in.dataType)).toArray - Some(new V2In(FieldReference(name), values)) - - // Because we only convert In to InSet in Optimizer when there are more than certain - // items. So it is possible we still get an In expression here that needs to be pushed - // down. - case in @ expressions.In(pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) => - val hSet = list.map(_.eval(EmptyRow)) - Some(new V2In(FieldReference(name), - hSet.toArray.map(LiteralValue(_, in.value.dataType)))) - - case expressions.IsNull(pushableColumn(name)) => - Some(new V2IsNull(FieldReference(name))) - case expressions.IsNotNull(pushableColumn(name)) => - Some(new V2IsNotNull(FieldReference(name))) - - case expressions.StartsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => - Some(new V2StringStartsWith(FieldReference(name), v)) - - case expressions.EndsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => - Some(new V2StringEndsWith(FieldReference(name), v)) - - case expressions.Contains(pushableColumn(name), Literal(v: UTF8String, StringType)) => - Some(new V2StringContains(FieldReference(name), v)) - - case expressions.Literal(true, BooleanType) => - Some(new V2AlwaysTrue) - - case expressions.Literal(false, BooleanType) => - Some(new V2AlwaysFalse) - - case _ => None + supportNestedPredicatePushdown: Boolean): Option[Predicate] = { + val pushablePredicate = PushablePredicate(supportNestedPredicatePushdown) + predicate match { + case pushablePredicate(expr) => Some(expr) + case _ => None + } } /** @@ -513,7 +451,7 @@ private[sql] object DataSourceV2Strategy { */ protected[sql] def translateFilterV2( predicate: Expression, - supportNestedPredicatePushdown: Boolean): Option[V2Filter] = { + supportNestedPredicatePushdown: Boolean): Option[Predicate] = { translateFilterV2WithMapping(predicate, None, supportNestedPredicatePushdown) } @@ -528,11 +466,11 @@ private[sql] object DataSourceV2Strategy { */ protected[sql] def translateFilterV2WithMapping( predicate: Expression, - translatedFilterToExpr: Option[mutable.HashMap[V2Filter, Expression]], + translatedFilterToExpr: Option[mutable.HashMap[Predicate, Expression]], nestedPredicatePushdownEnabled: Boolean) - : Option[V2Filter] = { + : Option[Predicate] = { predicate match { - case expressions.And(left, right) => + case And(left, right) => // See SPARK-12218 for detailed discussion // It is not safe to just convert one side if we do not understand the // other side. Here is an example used to explain the reason. 
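The SPARK-12218 hazard the comment above alludes to can be checked with plain booleans; the names and values below are hypothetical and only illustrate why an AND may not be translated one-sided when it sits under an OR:

// Original predicate: (a = 2 AND untranslatable(b)) OR (c > 0).
// Keeping only `a = 2` for the inner AND yields (a = 2) OR (c > 0),
// which accepts rows the original predicate rejects.
val (aEquals2, untranslatableB, cPositive) = (true, false, false)
val original        = (aEquals2 && untranslatableB) || cPositive   // false
val partiallyPushed = aEquals2 || cPositive                        // true
assert(original != partiallyPushed)   // the one-sided translation changes the result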
@@ -549,7 +487,7 @@ private[sql] object DataSourceV2Strategy { right, translatedFilterToExpr, nestedPredicatePushdownEnabled) } yield new V2And(leftFilter, rightFilter) - case expressions.Or(left, right) => + case Or(left, right) => for { leftFilter <- translateFilterV2WithMapping( left, translatedFilterToExpr, nestedPredicatePushdownEnabled) @@ -557,13 +495,12 @@ private[sql] object DataSourceV2Strategy { right, translatedFilterToExpr, nestedPredicatePushdownEnabled) } yield new V2Or(leftFilter, rightFilter) - case expressions.Not(child) => + case Not(child) => translateFilterV2WithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) .map(new V2Not(_)) case other => - val filter = translateLeafNodeFilterV2( - other, PushableColumn(nestedPredicatePushdownEnabled)) + val filter = translateLeafNodeFilterV2(other, nestedPredicatePushdownEnabled) if (filter.isDefined && translatedFilterToExpr.isDefined) { translatedFilterToExpr.get(filter.get) = predicate } @@ -572,20 +509,34 @@ private[sql] object DataSourceV2Strategy { } protected[sql] def rebuildExpressionFromFilter( - filter: V2Filter, - translatedFilterToExpr: mutable.HashMap[V2Filter, Expression]): Expression = { - filter match { + predicate: Predicate, + translatedFilterToExpr: mutable.HashMap[Predicate, Expression]): Expression = { + predicate match { case and: V2And => - expressions.And(rebuildExpressionFromFilter(and.left, translatedFilterToExpr), - rebuildExpressionFromFilter(and.right, translatedFilterToExpr)) + expressions.And( + rebuildExpressionFromFilter(and.left(), translatedFilterToExpr), + rebuildExpressionFromFilter(and.right(), translatedFilterToExpr)) case or: V2Or => - expressions.Or(rebuildExpressionFromFilter(or.left, translatedFilterToExpr), - rebuildExpressionFromFilter(or.right, translatedFilterToExpr)) + expressions.Or( + rebuildExpressionFromFilter(or.left(), translatedFilterToExpr), + rebuildExpressionFromFilter(or.right(), translatedFilterToExpr)) case not: V2Not => - expressions.Not(rebuildExpressionFromFilter(not.child, translatedFilterToExpr)) - case other => - translatedFilterToExpr.getOrElse(other, - throw new IllegalStateException("Failed to rebuild Expression for filter: " + filter)) + expressions.Not(rebuildExpressionFromFilter(not.child(), translatedFilterToExpr)) + case _ => + translatedFilterToExpr.getOrElse(predicate, + throw new IllegalStateException("Failed to rebuild Expression for filter: " + predicate)) } } } + +/** + * Get the expression of DS V2 to represent catalyst predicate that can be pushed down. 
+ */ +case class PushablePredicate(nestedPredicatePushdownEnabled: Boolean) { + + def unapply(e: Expression): Option[Predicate] = + new V2ExpressionBuilder(e, nestedPredicatePushdownEnabled, true).build().map { v => + assert(v.isInstanceOf[Predicate]) + v.asInstanceOf[Predicate] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 2bffa761dd9e9..2f55b7ee46ac7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.SortOrder -import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf @@ -35,9 +35,8 @@ object PushDownUtils extends PredicateHelper { * * @return pushed filter and post-scan filters. */ - def pushFilters( - scanBuilder: ScanBuilder, - filters: Seq[Expression]): (Either[Seq[sources.Filter], Seq[V2Filter]], Seq[Expression]) = { + def pushFilters(scanBuilder: ScanBuilder, filters: Seq[Expression]) + : (Either[Seq[sources.Filter], Seq[Predicate]], Seq[Expression]) = { scanBuilder match { case r: SupportsPushDownFilters => // A map from translated data source leaf node filters to original catalyst filter @@ -73,8 +72,8 @@ object PushDownUtils extends PredicateHelper { // expressions. For a `And`/`Or` predicate, it is possible that the predicate is partially // pushed down. This map can be used to construct a catalyst filter expression from the // input filter, or a superset(partial push down filter) of the input filter. - val translatedFilterToExpr = mutable.HashMap.empty[V2Filter, Expression] - val translatedFilters = mutable.ArrayBuffer.empty[V2Filter] + val translatedFilterToExpr = mutable.HashMap.empty[Predicate, Expression] + val translatedFilters = mutable.ArrayBuffer.empty[Predicate] // Catalyst filter expression that can't be translated to data source filters. val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] @@ -92,10 +91,10 @@ object PushDownUtils extends PredicateHelper { // Data source filters that need to be evaluated again after scanning. which means // the data source cannot guarantee the rows returned can pass these filters. // As a result we must return it so Spark can plan an extra filter operator. 
- val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter => - DataSourceV2Strategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) + val postScanFilters = r.pushPredicates(translatedFilters.toArray).map { predicate => + DataSourceV2Strategy.rebuildExpressionFromFilter(predicate, translatedFilterToExpr) } - (Right(r.pushedFilters), (untranslatableExprs ++ postScanFilters).toSeq) + (Right(r.pushedPredicates), (untranslatableExprs ++ postScanFilters).toSeq) case f: FileScanBuilder => val postScanFilters = f.pushFilters(filters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala index 20ced9c17f7e0..a95b4593fc397 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.expressions.filter.Predicate /** * Pushed down operators @@ -27,6 +28,7 @@ case class PushedDownOperators( aggregation: Option[Aggregation], sample: Option[TableSampleInfo], limit: Option[Int], - sortValues: Seq[SortOrder]) { + sortValues: Seq[SortOrder], + pushedPredicates: Seq[Predicate]) { assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index b97823fcd09e8..171110cc027ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum} +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources @@ -63,6 +64,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val pushedFiltersStr = if (pushedFilters.isLeft) { pushedFilters.left.get.mkString(", ") } else { + sHolder.pushedPredicates = pushedFilters.right.get pushedFilters.right.get.mkString(", ") } @@ -396,8 +398,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { f.pushedFilters() case _ => Array.empty[sources.Filter] } - val pushedDownOperators = PushedDownOperators(aggregation, - sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortOrders) + val pushedDownOperators = PushedDownOperators(aggregation, sHolder.pushedSample, + sHolder.pushedLimit, sHolder.sortOrders, sHolder.pushedPredicates) V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan } @@ -413,6 +415,8 @@ case class ScanBuilderHolder( var 
sortOrders: Seq[SortOrder] = Seq.empty[SortOrder] var pushedSample: Option[TableSampleInfo] = None + + var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index 87ec9f43804e4..f68f78d51fd96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -19,16 +19,17 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo -import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan} +import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.StructType case class JDBCScan( relation: JDBCRelation, prunedSchema: StructType, - pushedFilters: Array[Filter], + pushedPredicates: Array[Predicate], pushedAggregateColumn: Array[String] = Array(), groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], @@ -48,7 +49,7 @@ case class JDBCScan( } else { pushedAggregateColumn } - relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, + relation.buildScan(columnList, prunedSchema, pushedPredicates, groupByColumns, tableSample, pushedLimit, sortOrders) } }.asInstanceOf[T] @@ -63,7 +64,7 @@ case class JDBCScan( ("[]", "[]") } super.description() + ", prunedSchema: " + seqToString(prunedSchema) + - ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedPredicates: " + seqToString(pushedPredicates) + ", PushedAggregates: " + aggString + ", PushedGroupBy: " + groupByString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 61bf729bc8fbf..475f563856f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -22,12 +22,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} import 
org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.JdbcDialects -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType case class JDBCScanBuilder( @@ -35,7 +35,7 @@ case class JDBCScanBuilder( schema: StructType, jdbcOptions: JDBCOptions) extends ScanBuilder - with SupportsPushDownFilters + with SupportsPushDownV2Filters with SupportsPushDownRequiredColumns with SupportsPushDownAggregates with SupportsPushDownLimit @@ -45,7 +45,7 @@ case class JDBCScanBuilder( private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis - private var pushedFilter = Array.empty[Filter] + private var pushedPredicate = Array.empty[Predicate] private var finalSchema = schema @@ -55,18 +55,18 @@ case class JDBCScanBuilder( private var sortOrders: Array[SortOrder] = Array.empty[SortOrder] - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { if (jdbcOptions.pushDownPredicate) { val dialect = JdbcDialects.get(jdbcOptions.url) - val (pushed, unSupported) = filters.partition(JDBCRDD.compileFilter(_, dialect).isDefined) - this.pushedFilter = pushed + val (pushed, unSupported) = predicates.partition(dialect.compileExpression(_).isDefined) + this.pushedPredicate = pushed unSupported } else { - filters + predicates } } - override def pushedFilters(): Array[Filter] = pushedFilter + override def pushedPredicates(): Array[Predicate] = pushedPredicate private var pushedAggregateList: Array[String] = Array() @@ -170,7 +170,7 @@ case class JDBCScanBuilder( // "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column names from // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. 
- JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter, + JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate, pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index dd68953badf7a..9bf25aa0d633f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -35,27 +35,27 @@ private object DB2Dialect extends JdbcDialect { super.compileAggregate(aggFunction).orElse( aggFunction match { case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VARIANCE($distinct${f.inputs().head})") + Some(s"VARIANCE($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VARIANCE_SAMP($distinct${f.inputs().head})") + Some(s"VARIANCE_SAMP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV($distinct${f.inputs().head})") + Some(s"STDDEV($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + Some(s"STDDEV_SAMP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => - assert(f.inputs().length == 2) - Some(s"COVARIANCE(${f.inputs().head}, ${f.inputs().last})") + assert(f.children().length == 2) + Some(s"COVARIANCE(${f.children().head}, ${f.children().last})") case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => - assert(f.inputs().length == 2) - Some(s"COVARIANCE_SAMP(${f.inputs().head}, ${f.inputs().last})") + assert(f.children().length == 2) + Some(s"COVARIANCE_SAMP(${f.children().head}, ${f.children().last})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index bf838b8ed66eb..36c3c6be4a05c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -35,17 +35,17 @@ private object DerbyDialect extends JdbcDialect { super.compileAggregate(aggFunction).orElse( aggFunction match { case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"VAR_POP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"VAR_POP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"VAR_SAMP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"VAR_SAMP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => - assert(f.inputs().length == 1) - 
Some(s"STDDEV_POP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"STDDEV_POP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"STDDEV_SAMP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"STDDEV_SAMP(${f.children().head})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 1f422e5a59cf8..5f92f6dae9f11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -32,33 +32,33 @@ private object H2Dialect extends JdbcDialect { super.compileAggregate(aggFunction).orElse( aggFunction match { case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.inputs().head})") + Some(s"VAR_POP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.inputs().head})") + Some(s"VAR_SAMP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.inputs().head})") + Some(s"STDDEV_POP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + Some(s"STDDEV_SAMP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => - assert(f.inputs().length == 2) + assert(f.children().length == 2) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})") case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => - assert(f.inputs().length == 2) + assert(f.children().length == 2) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})") case f: GeneralAggregateFunc if f.name() == "CORR" => - assert(f.inputs().length == 2) + assert(f.children().length == 2) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"CORR($distinct${f.children().head}, ${f.children().last})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index e886d8b8deae7..c4b29fd9c9152 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -22,17 +22,19 @@ import java.time.{Instant, LocalDate} import java.util import scala.collection.mutable.ArrayBuilder +import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Since} 
import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.index.TableIndex -import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference} +import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors @@ -220,12 +222,18 @@ abstract class JdbcDialect extends Serializable with Logging{ } class JDBCSQLBuilder extends V2ExpressionSQLBuilder { - override def visitFieldReference(fieldRef: FieldReference): String = { - if (fieldRef.fieldNames().length != 1) { + override def visitLiteral(literal: Literal[_]): String = { + compileValue( + CatalystTypeConverters.convertToScala(literal.value(), literal.dataType())).toString + } + + override def visitNamedReference(namedRef: NamedReference): String = { + if (namedRef.fieldNames().length > 1) { throw new IllegalArgumentException( - "FieldReference with field name has multiple or zero parts unsupported: " + fieldRef); + QueryCompilationErrors.commandNotSupportNestedColumnError( + "Filter push down", namedRef.toString).getMessage); } - quoteIdentifier(fieldRef.fieldNames.head) + quoteIdentifier(namedRef.fieldNames.head) } } @@ -240,7 +248,9 @@ abstract class JdbcDialect extends Serializable with Logging{ try { Some(jdbcSQLBuilder.build(expr)) } catch { - case _: IllegalArgumentException => None + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 841f1c87319b5..8d2fbec55f919 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -47,21 +47,21 @@ private object MsSqlServerDialect extends JdbcDialect { super.compileAggregate(aggFunction).orElse( aggFunction match { case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VARP($distinct${f.inputs().head})") + Some(s"VARP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR($distinct${f.inputs().head})") + Some(s"VAR($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDEVP($distinct${f.inputs().head})") + Some(s"STDEVP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDEV($distinct${f.inputs().head})") + 
Some(s"STDEV($distinct${f.children().head})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index d73721de962d7..24f9bac74f86d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -43,17 +43,17 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { super.compileAggregate(aggFunction).orElse( aggFunction match { case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"VAR_POP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"VAR_POP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"VAR_SAMP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"VAR_SAMP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"STDDEV_POP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"STDDEV_POP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"STDDEV_SAMP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"STDDEV_SAMP(${f.children().head})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 71db7e9285f5e..40333c1757c4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -41,26 +41,26 @@ private case object OracleDialect extends JdbcDialect { super.compileAggregate(aggFunction).orElse( aggFunction match { case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"VAR_POP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"VAR_POP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"VAR_SAMP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"VAR_SAMP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"STDDEV_POP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"STDDEV_POP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => - assert(f.inputs().length == 1) - Some(s"STDDEV_SAMP(${f.inputs().head})") + assert(f.children().length == 1) + Some(s"STDDEV_SAMP(${f.children().head})") case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => - assert(f.inputs().length == 2) - Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})") + assert(f.children().length == 2) + Some(s"COVAR_POP(${f.children().head}, ${f.children().last})") case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => - assert(f.inputs().length == 2) - Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})") + assert(f.children().length == 2) + Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})") case f: GeneralAggregateFunc if 
f.name() == "CORR" && f.isDistinct == false => - assert(f.inputs().length == 2) - Some(s"CORR(${f.inputs().head}, ${f.inputs().last})") + assert(f.children().length == 2) + Some(s"CORR(${f.children().head}, ${f.children().last})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index e2023d110ae4b..a668d66ee2f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -41,33 +41,33 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { super.compileAggregate(aggFunction).orElse( aggFunction match { case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.inputs().head})") + Some(s"VAR_POP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.inputs().head})") + Some(s"VAR_SAMP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.inputs().head})") + Some(s"STDDEV_POP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + Some(s"STDDEV_SAMP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => - assert(f.inputs().length == 2) + assert(f.children().length == 2) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})") case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => - assert(f.inputs().length == 2) + assert(f.children().length == 2) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})") case f: GeneralAggregateFunc if f.name() == "CORR" => - assert(f.inputs().length == 2) + assert(f.children().length == 2) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"CORR($distinct${f.children().head}, ${f.children().last})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 13e16d24d048d..79fb710cf03b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -35,30 +35,30 @@ private case object TeradataDialect extends JdbcDialect { super.compileAggregate(aggFunction).orElse( aggFunction match { case f: GeneralAggregateFunc if f.name() == "VAR_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - 
Some(s"VAR_POP($distinct${f.inputs().head})") + Some(s"VAR_POP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.inputs().head})") + Some(s"VAR_SAMP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.inputs().head})") + Some(s"STDDEV_POP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => - assert(f.inputs().length == 1) + assert(f.children().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + Some(s"STDDEV_SAMP($distinct${f.children().head})") case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => - assert(f.inputs().length == 2) - Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})") + assert(f.children().length == 2) + Some(s"COVAR_POP(${f.children().head}, ${f.children().last})") case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => - assert(f.inputs().length == 2) - Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})") + assert(f.children().length == 2) + Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})") case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => - assert(f.inputs().length == 2) - Some(s"CORR(${f.inputs().head}, ${f.inputs().last})") + assert(f.children().length == 2) + Some(s"CORR(${f.children().head}, ${f.children().last})") case _ => None } ) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java index b92206c6a5444..ec532da61042f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java @@ -17,21 +17,23 @@ package test.org.apache.spark.sql.connector; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.expressions.filter.Filter; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.connector.read.*; -import org.apache.spark.sql.connector.expressions.filter.GreaterThan; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - public class JavaAdvancedDataSourceV2WithV2Filter implements TestingV2Source { @Override @@ -48,7 +50,7 @@ static class AdvancedScanBuilderWithV2Filter implements ScanBuilder, Scan, SupportsPushDownV2Filters, 
SupportsPushDownRequiredColumns { private StructType requiredSchema = TestingV2Source.schema(); - private Filter[] filters = new Filter[0]; + private Predicate[] predicates = new Predicate[0]; @Override public void pruneColumns(StructType requiredSchema) { @@ -61,32 +63,38 @@ public StructType readSchema() { } @Override - public Filter[] pushFilters(Filter[] filters) { - Filter[] supported = Arrays.stream(filters).filter(f -> { - if (f instanceof GreaterThan) { - GreaterThan gt = (GreaterThan) f; - return gt.column().describe().equals("i") && gt.value().value() instanceof Integer; + public Predicate[] pushPredicates(Predicate[] predicates) { + Predicate[] supported = Arrays.stream(predicates).filter(f -> { + if (f.name().equals(">")) { + assert(f.children()[0] instanceof FieldReference); + FieldReference column = (FieldReference) f.children()[0]; + assert(f.children()[1] instanceof LiteralValue); + Literal value = (Literal) f.children()[1]; + return column.describe().equals("i") && value.value() instanceof Integer; } else { return false; } - }).toArray(Filter[]::new); - - Filter[] unsupported = Arrays.stream(filters).filter(f -> { - if (f instanceof GreaterThan) { - GreaterThan gt = (GreaterThan) f; - return !gt.column().describe().equals("i") || !(gt.value().value() instanceof Integer); + }).toArray(Predicate[]::new); + + Predicate[] unsupported = Arrays.stream(predicates).filter(f -> { + if (f.name().equals(">")) { + assert(f.children()[0] instanceof FieldReference); + FieldReference column = (FieldReference) f.children()[0]; + assert(f.children()[1] instanceof LiteralValue); + Literal value = (LiteralValue) f.children()[1]; + return !column.describe().equals("i") || !(value.value() instanceof Integer); } else { return true; } - }).toArray(Filter[]::new); + }).toArray(Predicate[]::new); - this.filters = supported; + this.predicates = supported; return unsupported; } @Override - public Filter[] pushedFilters() { - return filters; + public Predicate[] pushedPredicates() { + return predicates; } @Override @@ -96,18 +104,18 @@ public Scan build() { @Override public Batch toBatch() { - return new AdvancedBatchWithV2Filter(requiredSchema, filters); + return new AdvancedBatchWithV2Filter(requiredSchema, predicates); } } public static class AdvancedBatchWithV2Filter implements Batch { // Exposed for testing. 
public StructType requiredSchema; - public Filter[] filters; + public Predicate[] predicates; - AdvancedBatchWithV2Filter(StructType requiredSchema, Filter[] filters) { + AdvancedBatchWithV2Filter(StructType requiredSchema, Predicate[] predicates) { this.requiredSchema = requiredSchema; - this.filters = filters; + this.predicates = predicates; } @Override @@ -115,11 +123,14 @@ public InputPartition[] planInputPartitions() { List res = new ArrayList<>(); Integer lowerBound = null; - for (Filter filter : filters) { - if (filter instanceof GreaterThan) { - GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.column().describe()) && f.value().value() instanceof Integer) { - lowerBound = (Integer) f.value().value(); + for (Predicate predicate : predicates) { + if (predicate.name().equals(">")) { + assert(predicate.children()[0] instanceof FieldReference); + FieldReference column = (FieldReference) predicate.children()[0]; + assert(predicate.children()[1] instanceof LiteralValue); + Literal value = (Literal) predicate.children()[1]; + if ("i".equals(column.describe()) && value.value() instanceof Integer) { + lowerBound = (Integer) value.value(); break; } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 1f19836834171..cff58d7367317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter, GreaterThan => V2GreaterThan} +import org.apache.spark.sql.connector.expressions.{Literal, Transform} +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -155,11 +155,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { val batch = getBatchWithV2Filter(q1) - assert(batch.filters.isEmpty) + assert(batch.predicates.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } else { val batch = getJavaBatchWithV2Filter(q1) - assert(batch.filters.isEmpty) + assert(batch.predicates.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } @@ -167,11 +167,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { val batch = getBatchWithV2Filter(q2) - assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) assert(batch.requiredSchema.fieldNames === Seq("i", "j")) } else { val batch = getJavaBatchWithV2Filter(q2) - assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) + 
assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) assert(batch.requiredSchema.fieldNames === Seq("i", "j")) } @@ -179,11 +179,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { val batch = getBatchWithV2Filter(q3) - assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) assert(batch.requiredSchema.fieldNames === Seq("i")) } else { val batch = getJavaBatchWithV2Filter(q3) - assert(batch.filters.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) assert(batch.requiredSchema.fieldNames === Seq("i")) } @@ -192,12 +192,12 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { val batch = getBatchWithV2Filter(q4) // 'j < 10 is not supported by the testing data source. - assert(batch.filters.isEmpty) + assert(batch.predicates.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } else { val batch = getJavaBatchWithV2Filter(q4) // 'j < 10 is not supported by the testing data source. - assert(batch.filters.isEmpty) + assert(batch.predicates.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } } @@ -683,7 +683,7 @@ class AdvancedScanBuilderWithV2Filter extends ScanBuilder with Scan with SupportsPushDownV2Filters with SupportsPushDownRequiredColumns { var requiredSchema = TestingV2Source.schema - var filters = Array.empty[V2Filter] + var predicates = Array.empty[Predicate] override def pruneColumns(requiredSchema: StructType): Unit = { this.requiredSchema = requiredSchema @@ -691,29 +691,32 @@ class AdvancedScanBuilderWithV2Filter extends ScanBuilder override def readSchema(): StructType = requiredSchema - override def pushFilters(filters: Array[V2Filter]): Array[V2Filter] = { - val (supported, unsupported) = filters.partition { - case _: V2GreaterThan => true + override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { + val (supported, unsupported) = predicates.partition { + case p: Predicate if p.name() == ">" => true case _ => false } - this.filters = supported + this.predicates = supported unsupported } - override def pushedFilters(): Array[V2Filter] = filters + override def pushedPredicates(): Array[Predicate] = predicates override def build(): Scan = this - override def toBatch: Batch = new AdvancedBatchWithV2Filter(filters, requiredSchema) + override def toBatch: Batch = new AdvancedBatchWithV2Filter(predicates, requiredSchema) } class AdvancedBatchWithV2Filter( - val filters: Array[V2Filter], + val predicates: Array[Predicate], val requiredSchema: StructType) extends Batch { override def planInputPartitions(): Array[InputPartition] = { - val lowerBound = filters.collectFirst { - case gt: V2GreaterThan => gt.value + val lowerBound = predicates.collectFirst { + case p: Predicate if p.name().equals(">") => + val value = p.children()(1) + assert(value.isInstanceOf[Literal[_]]) + value.asInstanceOf[Literal[_]] } val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala new file mode 100644 
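Editor's note (illustrative sketch, not part of the patch): with the V2 API the test sources above no longer match on concrete filter classes such as GreaterThan; they inspect a generic Predicate through its name() and children(). The following minimal Scala sketch shows that pattern, assuming only the Predicate, FieldReference, LiteralValue and Literal APIs used elsewhere in this PR; the object and method names (PredicateSketch, lowerBoundForI) are hypothetical and exist only for illustration.

import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.types.IntegerType

object PredicateSketch {
  // Extracts the integer lower bound from a predicate of the shape `i > <int literal>`,
  // matching on the predicate's name and children rather than on a concrete filter class.
  def lowerBoundForI(p: Predicate): Option[Int] = p.children() match {
    case Array(col: FieldReference, lit: Literal[_])
        if p.name() == ">" && col.describe() == "i" && lit.value().isInstanceOf[Int] =>
      Some(lit.value().asInstanceOf[Int])
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    // A `>` predicate over column "i", similar in shape to what a Catalyst GreaterThan
    // over an integer column would translate to.
    val gt = new Predicate(">",
      Array[Expression](FieldReference("i"), LiteralValue(5, IntegerType)))
    assert(lowerBoundForI(gt).contains(5))
  }
}

A scan builder could apply a helper like this when splitting supported from unsupported predicates in pushPredicates, in the same spirit as AdvancedScanBuilderWithV2Filter above.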
index 0000000000000..6296da47cca51 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.BooleanType + +class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { + test("SPARK-36644: Push down boolean column filter") { + testTranslateFilter(Symbol("col").boolean, + Some(new Predicate("=", Array(FieldReference("col"), LiteralValue(true, BooleanType))))) + } + + /** + * Translate the given Catalyst [[Expression]] into data source V2 [[Predicate]] + * then verify against the given [[Predicate]]. + */ + def testTranslateFilter(catalystFilter: Expression, result: Option[Predicate]): Unit = { + assertResult(result) { + DataSourceV2Strategy.translateFilterV2(catalystFilter, true) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2FiltersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2FiltersSuite.scala deleted file mode 100644 index b457211b7f89f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2FiltersSuite.scala +++ /dev/null @@ -1,204 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.connector.expressions.{FieldReference, Literal, LiteralValue} -import org.apache.spark.sql.connector.expressions.filter._ -import org.apache.spark.sql.execution.datasources.v2.FiltersV2Suite.ref -import org.apache.spark.sql.types.IntegerType -import org.apache.spark.unsafe.types.UTF8String - -class FiltersV2Suite extends SparkFunSuite { - - test("nested columns") { - val filter1 = new EqualTo(ref("a", "B"), LiteralValue(1, IntegerType)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a.B")) - assert(filter1.describe.equals("a.B = 1")) - - val filter2 = new EqualTo(ref("a", "b.c"), LiteralValue(1, IntegerType)) - assert(filter2.references.map(_.describe()).toSeq == Seq("a.`b.c`")) - assert(filter2.describe.equals("a.`b.c` = 1")) - - val filter3 = new EqualTo(ref("`a`.b", "c"), LiteralValue(1, IntegerType)) - assert(filter3.references.map(_.describe()).toSeq == Seq("```a``.b`.c")) - assert(filter3.describe.equals("```a``.b`.c = 1")) - } - - test("AlwaysTrue") { - val filter1 = new AlwaysTrue - val filter2 = new AlwaysTrue - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).length == 0) - assert(filter1.describe.equals("TRUE")) - } - - test("AlwaysFalse") { - val filter1 = new AlwaysFalse - val filter2 = new AlwaysFalse - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).length == 0) - assert(filter1.describe.equals("FALSE")) - } - - test("EqualTo") { - val filter1 = new EqualTo(ref("a"), LiteralValue(1, IntegerType)) - val filter2 = new EqualTo(ref("a"), LiteralValue(1, IntegerType)) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("a = 1")) - } - - test("EqualNullSafe") { - val filter1 = new EqualNullSafe(ref("a"), LiteralValue(1, IntegerType)) - val filter2 = new EqualNullSafe(ref("a"), LiteralValue(1, IntegerType)) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("a <=> 1")) - } - - test("GreaterThan") { - val filter1 = new GreaterThan(ref("a"), LiteralValue(1, IntegerType)) - val filter2 = new GreaterThan(ref("a"), LiteralValue(1, IntegerType)) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("a > 1")) - } - - test("GreaterThanOrEqual") { - val filter1 = new GreaterThanOrEqual(ref("a"), LiteralValue(1, IntegerType)) - val filter2 = new GreaterThanOrEqual(ref("a"), LiteralValue(1, IntegerType)) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("a >= 1")) - } - - test("LessThan") { - val filter1 = new LessThan(ref("a"), LiteralValue(1, IntegerType)) - val filter2 = new LessThan(ref("a"), LiteralValue(1, IntegerType)) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("a < 1")) - } - - test("LessThanOrEqual") { - val filter1 = new LessThanOrEqual(ref("a"), LiteralValue(1, IntegerType)) - val filter2 = new LessThanOrEqual(ref("a"), LiteralValue(1, IntegerType)) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("a <= 1")) - } - - test("In") { - val filter1 = new In(ref("a"), - Array(LiteralValue(1, IntegerType), 
LiteralValue(2, IntegerType), - LiteralValue(3, IntegerType), LiteralValue(4, IntegerType))) - val filter2 = new In(ref("a"), - Array(LiteralValue(4, IntegerType), LiteralValue(2, IntegerType), - LiteralValue(3, IntegerType), LiteralValue(1, IntegerType))) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("a IN (1, 2, 3, 4)")) - val values: Array[Literal[_]] = new Array[Literal[_]](1000) - for (i <- 0 until 1000) { - values(i) = LiteralValue(i, IntegerType) - } - val filter3 = new In(ref("a"), values) - var expected = "a IN (" - for (i <- 0 until 50) { - expected += i + ", " - } - expected = expected.dropRight(2) // remove the last ", " - expected += "...)" - assert(filter3.describe.equals(expected)) - } - - test("IsNull") { - val filter1 = new IsNull(ref("a")) - val filter2 = new IsNull(ref("a")) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("a IS NULL")) - } - - test("IsNotNull") { - val filter1 = new IsNotNull(ref("a")) - val filter2 = new IsNotNull(ref("a")) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("a IS NOT NULL")) - } - - test("Not") { - val filter1 = new Not(new LessThan(ref("a"), LiteralValue(1, IntegerType))) - val filter2 = new Not(new LessThan(ref("a"), LiteralValue(1, IntegerType))) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("NOT (a < 1)")) - } - - test("And") { - val filter1 = new And(new EqualTo(ref("a"), LiteralValue(1, IntegerType)), - new EqualTo(ref("b"), LiteralValue(1, IntegerType))) - val filter2 = new And(new EqualTo(ref("a"), LiteralValue(1, IntegerType)), - new EqualTo(ref("b"), LiteralValue(1, IntegerType))) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a", "b")) - assert(filter1.describe.equals("(a = 1) AND (b = 1)")) - } - - test("Or") { - val filter1 = new Or(new EqualTo(ref("a"), LiteralValue(1, IntegerType)), - new EqualTo(ref("b"), LiteralValue(1, IntegerType))) - val filter2 = new Or(new EqualTo(ref("a"), LiteralValue(1, IntegerType)), - new EqualTo(ref("b"), LiteralValue(1, IntegerType))) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a", "b")) - assert(filter1.describe.equals("(a = 1) OR (b = 1)")) - } - - test("StringStartsWith") { - val filter1 = new StringStartsWith(ref("a"), UTF8String.fromString("str")) - val filter2 = new StringStartsWith(ref("a"), UTF8String.fromString("str")) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("STRING_STARTS_WITH(a, str)")) - } - - test("StringEndsWith") { - val filter1 = new StringEndsWith(ref("a"), UTF8String.fromString("str")) - val filter2 = new StringEndsWith(ref("a"), UTF8String.fromString("str")) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("STRING_ENDS_WITH(a, str)")) - } - - test("StringContains") { - val filter1 = new StringContains(ref("a"), UTF8String.fromString("str")) - val filter2 = new StringContains(ref("a"), UTF8String.fromString("str")) - assert(filter1.equals(filter2)) - assert(filter1.references.map(_.describe()).toSeq == Seq("a")) - assert(filter1.describe.equals("STRING_CONTAINS(a, str)")) - } 
-} - -object FiltersV2Suite { - private[sql] def ref(parts: String*): FieldReference = { - new FieldReference(parts) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala new file mode 100644 index 0000000000000..2d6e6fcf16174 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter._ +import org.apache.spark.sql.execution.datasources.v2.V2PredicateSuite.ref +import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class V2PredicateSuite extends SparkFunSuite { + + test("nested columns") { + val predicate1 = + new Predicate("=", Array[Expression](ref("a", "B"), LiteralValue(1, IntegerType))) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a.B")) + assert(predicate1.describe.equals("a.B = 1")) + + val predicate2 = + new Predicate("=", Array[Expression](ref("a", "b.c"), LiteralValue(1, IntegerType))) + assert(predicate2.references.map(_.describe()).toSeq == Seq("a.`b.c`")) + assert(predicate2.describe.equals("a.`b.c` = 1")) + + val predicate3 = + new Predicate("=", Array[Expression](ref("`a`.b", "c"), LiteralValue(1, IntegerType))) + assert(predicate3.references.map(_.describe()).toSeq == Seq("```a``.b`.c")) + assert(predicate3.describe.equals("```a``.b`.c = 1")) + } + + test("AlwaysTrue") { + val predicate1 = new AlwaysTrue + val predicate2 = new AlwaysTrue + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).length == 0) + assert(predicate1.describe.equals("TRUE")) + } + + test("AlwaysFalse") { + val predicate1 = new AlwaysFalse + val predicate2 = new AlwaysFalse + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).length == 0) + assert(predicate1.describe.equals("FALSE")) + } + + test("EqualTo") { + val predicate1 = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + val predicate2 = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a = 1")) + } + + test("EqualNullSafe") { + val predicate1 = new Predicate("<=>", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + val 
predicate2 = new Predicate("<=>", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("(a = 1) OR (a IS NULL AND 1 IS NULL)")) + } + + test("In") { + val predicate1 = new Predicate("IN", + Array(ref("a"), LiteralValue(1, IntegerType), LiteralValue(2, IntegerType), + LiteralValue(3, IntegerType), LiteralValue(4, IntegerType))) + val predicate2 = new Predicate("IN", + Array(ref("a"), LiteralValue(4, IntegerType), LiteralValue(2, IntegerType), + LiteralValue(3, IntegerType), LiteralValue(1, IntegerType))) + assert(!predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a IN (1, 2, 3, 4)")) + val values: Array[Literal[_]] = new Array[Literal[_]](1000) + var expected = "a IN (" + for (i <- 0 until 1000) { + values(i) = LiteralValue(i, IntegerType) + expected += i + ", " + } + val predicate3 = new Predicate("IN", (ref("a") +: values).toArray[Expression]) + expected = expected.dropRight(2) // remove the last ", " + expected += ")" + assert(predicate3.describe.equals(expected)) + } + + test("IsNull") { + val predicate1 = new Predicate("IS_NULL", Array[Expression](ref("a"))) + val predicate2 = new Predicate("IS_NULL", Array[Expression](ref("a"))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a IS NULL")) + } + + test("IsNotNull") { + val predicate1 = new Predicate("IS_NOT_NULL", Array[Expression](ref("a"))) + val predicate2 = new Predicate("IS_NOT_NULL", Array[Expression](ref("a"))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a IS NOT NULL")) + } + + test("Not") { + val predicate1 = new Not( + new Predicate("<", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))) + val predicate2 = new Not( + new Predicate("<", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("NOT (a < 1)")) + } + + test("And") { + val predicate1 = new And( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) + val predicate2 = new And( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b")) + assert(predicate1.describe.equals("(a = 1) AND (b = 1)")) + } + + test("Or") { + val predicate1 = new Or( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) + val predicate2 = new Or( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b")) + assert(predicate1.describe.equals("(a = 1) OR (b = 1)")) + } + + test("StringStartsWith") { + val literal = LiteralValue(UTF8String.fromString("str"), StringType) + val 
predicate1 = new Predicate("STARTS_WITH", + Array[Expression](ref("a"), literal)) + val predicate2 = new Predicate("STARTS_WITH", + Array[Expression](ref("a"), literal)) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a LIKE 'str%'")) + } + + test("StringEndsWith") { + val literal = LiteralValue(UTF8String.fromString("str"), StringType) + val predicate1 = new Predicate("ENDS_WITH", + Array[Expression](ref("a"), literal)) + val predicate2 = new Predicate("ENDS_WITH", + Array[Expression](ref("a"), literal)) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a LIKE '%str'")) + } + + test("StringContains") { + val literal = LiteralValue(UTF8String.fromString("str"), StringType) + val predicate1 = new Predicate("CONTAINS", + Array[Expression](ref("a"), literal)) + val predicate2 = new Predicate("CONTAINS", + Array[Expression](ref("a"), literal)) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a LIKE '%str%'")) + } +} + +object V2PredicateSuite { + private[sql] def ref(parts: String*): FieldReference = { + new FieldReference(parts) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 3cb91b8b00190..8f690eeaff901 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeTestUtils import org.apache.spark.sql.execution.{DataSourceScanExec, ExtendedMode} import org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCommand} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation, JdbcUtils} import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ @@ -773,33 +773,36 @@ class JDBCSuite extends QueryTest } test("compile filters") { - val compileFilter = PrivateMethod[Option[String]](Symbol("compileFilter")) def doCompileFilter(f: Filter): String = - JDBCRDD invokePrivate compileFilter(f, JdbcDialects.get("jdbc:")) getOrElse("") - assert(doCompileFilter(EqualTo("col0", 3)) === """"col0" = 3""") - assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === """(NOT ("col1" = 'abc'))""") - assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) - === """("col0" = 0) AND ("col1" = 'def')""") - assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi"))) - === """("col0" = 2) OR ("col1" = 'ghi')""") - assert(doCompileFilter(LessThan("col0", 5)) === """"col0" < 5""") - assert(doCompileFilter(LessThan("col3", - Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col3" < '1995-11-21 00:00:00.0'""") - assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) - === """"col4" < '1983-08-04'""") - assert(doCompileFilter(LessThanOrEqual("col0", 5)) === """"col0" <= 5""") - assert(doCompileFilter(GreaterThan("col0", 3)) === """"col0" > 3""") - 
assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === """"col0" >= 3""") - assert(doCompileFilter(In("col1", Array("jkl"))) === """"col1" IN ('jkl')""") - assert(doCompileFilter(In("col1", Array.empty)) === - """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""") - assert(doCompileFilter(Not(In("col1", Array("mno", "pqr")))) - === """(NOT ("col1" IN ('mno', 'pqr')))""") - assert(doCompileFilter(IsNull("col1")) === """"col1" IS NULL""") - assert(doCompileFilter(IsNotNull("col1")) === """"col1" IS NOT NULL""") - assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) - === """((NOT ("col0" != 'abc' OR "col0" IS NULL OR 'abc' IS NULL) """ - + """OR ("col0" IS NULL AND 'abc' IS NULL))) AND ("col1" = 'def')""") + JdbcDialects.get("jdbc:").compileExpression(f.toV2).getOrElse("") + + Seq(("col0", "col1"), ("`col0`", "`col1`")).foreach { case(col0, col1) => + assert(doCompileFilter(EqualTo(col0, 3)) === """"col0" = 3""") + assert(doCompileFilter(Not(EqualTo(col1, "abc"))) === """NOT ("col1" = 'abc')""") + assert(doCompileFilter(And(EqualTo(col0, 0), EqualTo(col1, "def"))) + === """("col0" = 0) AND ("col1" = 'def')""") + assert(doCompileFilter(Or(EqualTo(col0, 2), EqualTo(col1, "ghi"))) + === """("col0" = 2) OR ("col1" = 'ghi')""") + assert(doCompileFilter(LessThan(col0, 5)) === """"col0" < 5""") + assert(doCompileFilter(LessThan(col0, + Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col0" < '1995-11-21 00:00:00.0'""") + assert(doCompileFilter(LessThan(col0, Date.valueOf("1983-08-04"))) + === """"col0" < '1983-08-04'""") + assert(doCompileFilter(LessThanOrEqual(col0, 5)) === """"col0" <= 5""") + assert(doCompileFilter(GreaterThan(col0, 3)) === """"col0" > 3""") + assert(doCompileFilter(GreaterThanOrEqual(col0, 3)) === """"col0" >= 3""") + assert(doCompileFilter(In(col1, Array("jkl"))) === """"col1" IN ('jkl')""") + assert(doCompileFilter(In(col1, Array.empty)) === + """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""") + assert(doCompileFilter(Not(In(col1, Array("mno", "pqr")))) + === """NOT ("col1" IN ('mno', 'pqr'))""") + assert(doCompileFilter(IsNull(col1)) === """"col1" IS NULL""") + assert(doCompileFilter(IsNotNull(col1)) === """"col1" IS NOT NULL""") + assert(doCompileFilter(And(EqualNullSafe(col0, "abc"), EqualTo(col1, "def"))) + === """(("col0" = 'abc') OR ("col0" IS NULL AND 'abc' IS NULL))""" + + """ AND ("col1" = 'def')""") + } + assert(doCompileFilter(EqualTo("col0.nested", 3)).isEmpty) } test("Dialect unregister") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 6d0acdc700723..23f28c2f9c3d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, sum, udf} +import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, not, sum, udf, when} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -70,17 +70,32 @@ class 
JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate() conn.prepareStatement( "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + - " bonus DOUBLE)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)") - .executeUpdate() + " bonus DOUBLE, is_manager BOOLEAN)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000, true)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200, false)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200, false)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300, true)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200, true)").executeUpdate() + conn.prepareStatement( + "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate() + + // scalastyle:off + conn.prepareStatement( + "CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT NULL)").executeUpdate() + // scalastyle:on + conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (2)").executeUpdate() + conn.prepareStatement( + """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate() + conn.prepareStatement( + """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate() } } @@ -112,7 +127,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df1 = spark.read.table("h2.test.employee") .where($"dept" === 1).limit(1) checkPushedLimit(df1, Some(1)) - checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0))) + checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df2 = spark.read .option("partitionColumn", "dept") @@ -123,7 +138,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"dept" > 1) .limit(1) checkPushedLimit(df2, Some(1)) - checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0))) + checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0, false))) val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1") val scan = df3.queryExecution.optimizedPlan.collectFirst { @@ -175,12 +190,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .sort("salary") .limit(1) checkPushedLimit(df1, Some(1), createSortValues()) - checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) + checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df2 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary").limit(1) 
checkPushedLimit(df2, Some(1), createSortValues()) - checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0))) + checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df3 = spark.read .option("partitionColumn", "dept") @@ -193,7 +208,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .limit(1) checkPushedLimit( df3, Some(1), createSortValues(SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) - checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0))) + checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false))) val df4 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1") @@ -207,7 +222,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df5 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary") checkPushedLimit(df5, None) - checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0))) + checkAnswer(df5, + Seq(Row(1, "cathy", 9000.00, 1200.0, false), Row(1, "amy", 10000.00, 1000.0, true))) val df6 = spark.read .table("h2.test.employee") @@ -234,7 +250,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .sort(sub($"NAME")) .limit(1) checkPushedLimit(df8) - checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0))) + checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false))) } private def createSortValues( @@ -253,11 +269,165 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]" + "PushedFilters: [ID IS NOT NULL, ID > 1]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Row("mary", 2)) + + val df2 = spark.table("h2.test.employee").filter($"name".isin("amy", "cathy")) + + checkFiltersRemoved(df2) + + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [NAME IN ('amy', 'cathy')]" + checkKeywordsExistsInExplain(df2, expected_plan_fragment) + } + + checkAnswer(df2, Seq(Row(1, "amy", 10000, 1000, true), Row(1, "cathy", 9000, 1200, false))) + + val df3 = spark.table("h2.test.employee").filter($"name".startsWith("a")) + + checkFiltersRemoved(df3) + + df3.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%']" + checkKeywordsExistsInExplain(df3, expected_plan_fragment) + } + + checkAnswer(df3, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false))) + + val df4 = spark.table("h2.test.employee").filter($"is_manager") + + checkFiltersRemoved(df4) + + df4.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true]" + checkKeywordsExistsInExplain(df4, expected_plan_fragment) + } + + checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "david", 10000, 1300, true), + Row(6, "jen", 12000, 1200, true))) + + val df5 = spark.table("h2.test.employee").filter($"is_manager".and($"salary" > 10000)) + + checkFiltersRemoved(df5) + + df5.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IS_MANAGER IS NOT NULL, SALARY IS NOT NULL, " + + "IS_MANAGER = true, SALARY > 10000.00]" + 
checkKeywordsExistsInExplain(df5, expected_plan_fragment) + } + + checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true))) + + val df6 = spark.table("h2.test.employee").filter($"is_manager".or($"salary" > 10000)) + + checkFiltersRemoved(df6) + + df6.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [(IS_MANAGER = true) OR (SALARY > 10000.00)], " + checkKeywordsExistsInExplain(df6, expected_plan_fragment) + } + + checkAnswer(df6, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false), + Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) + + val df7 = spark.table("h2.test.employee").filter(not($"is_manager") === true) + + checkFiltersRemoved(df7) + + df7.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IS_MANAGER IS NOT NULL, NOT (IS_MANAGER = true)], " + checkKeywordsExistsInExplain(df7, expected_plan_fragment) + } + + checkAnswer(df7, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "alex", 12000, 1200, false))) + + val df8 = spark.table("h2.test.employee").filter($"is_manager" === true) + + checkFiltersRemoved(df8) + + df8.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true], " + checkKeywordsExistsInExplain(df8, expected_plan_fragment) + } + + checkAnswer(df8, Seq(Row(1, "amy", 10000, 1000, true), + Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) + + val df9 = spark.table("h2.test.employee") + .filter(when($"dept" > 1, true).when($"is_manager", false).otherwise($"dept" > 3)) + + checkFiltersRemoved(df9) + + df9.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedFilters: [CASE WHEN DEPT > 1 THEN TRUE WHEN IS_MANAGER = true THEN FALSE" + + " ELSE DEPT > 3 END], " + checkKeywordsExistsInExplain(df9, expected_plan_fragment) + } + + checkAnswer(df9, Seq(Row(2, "alex", 12000, 1200, false), + Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) + } + + test("scan with complex filter push-down") { + Seq(false, true).foreach { ansiMode => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { + val df = spark.table("h2.test.people").filter($"id" + 1 > 1) + + checkFiltersRemoved(df, ansiMode) + + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = if (ansiMode) { + "PushedFilters: [ID IS NOT NULL, (ID + 1) > 1]" + } else { + "PushedFilters: [ID IS NOT NULL]" + } + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + + checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2))) + + val df2 = sql(""" + |SELECT * FROM h2.test.employee + |WHERE (CASE WHEN SALARY > 10000 THEN BONUS ELSE BONUS + 200 END) > 1200 + |""".stripMargin) + + checkFiltersRemoved(df2, ansiMode) + + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = if (ansiMode) { + "PushedFilters: [(CASE WHEN SALARY > 10000.00 THEN BONUS" + + " ELSE BONUS + 200.0 END) > 1200.0]" + } else { + "PushedFilters: []" + } + checkKeywordsExistsInExplain(df2, expected_plan_fragment) + } + + checkAnswer(df2, + Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) + } + } } test("scan with column pruning") { @@ -412,18 +582,22 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with 
ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT]" + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) } - private def checkFiltersRemoved(df: DataFrame): Unit = { + private def checkFiltersRemoved(df: DataFrame, removed: Boolean = true): Unit = { val filters = df.queryExecution.optimizedPlan.collect { case f: Filter => f } - assert(filters.isEmpty) + if (removed) { + assert(filters.isEmpty) + } else { + assert(filters.nonEmpty) + } } test("scan with aggregate push-down: MAX AVG with filter without group by") { @@ -437,8 +611,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [MAX(ID), AVG(ID)], " + - "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " + - "PushedGroupByColumns: []" + "PushedFilters: [ID IS NOT NULL, ID > 0], " + + "PushedGroupByColumns: [], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(2, 1.5))) @@ -553,7 +727,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [SUM(SALARY)], " + "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT]" + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) @@ -567,7 +741,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [SUM(DISTINCT SALARY)], " + "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT]" + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) @@ -585,8 +759,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT, NAME]" + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT, NAME], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), @@ -605,8 +779,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [MAX(SALARY)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT, NAME]" + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT, NAME], " checkKeywordsExistsInExplain(df1, expected_plan_fragment) } checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000), @@ -623,8 +797,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT, NAME]" + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + 
"PushedGroupByColumns: [DEPT, NAME], " checkKeywordsExistsInExplain(df2, expected_plan_fragment) } checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), @@ -640,7 +814,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel df3.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " checkKeywordsExistsInExplain(df3, expected_plan_fragment) } checkAnswer(df3, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), @@ -659,8 +833,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT]" + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200))) @@ -676,7 +850,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val expected_plan_fragment = "PushedAggregates: [MIN(SALARY)], " + "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT]" + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000))) @@ -699,8 +873,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [SUM(SALARY)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT]" + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(query, expected_plan_fragment) } checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000))) @@ -742,8 +916,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT]" + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) @@ -758,8 +932,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT]" + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) @@ -774,8 +948,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT]" + "PushedFilters: [DEPT IS NOT NULL, DEPT 
> 0], " + + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) @@ -790,15 +964,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [CORR(BONUS, BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupByColumns: [DEPT]" + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) } test("scan with aggregate push-down: aggregate over alias NOT push down") { - val cols = Seq("a", "b", "c", "d") + val cols = Seq("a", "b", "c", "d", "e") val df1 = sql("select * from h2.test.employee").toDF(cols: _*) val df2 = df1.groupBy().sum("c") checkAggregateRemoved(df2, false) @@ -854,10 +1028,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregates: [COUNT(CASE WHEN ((SALARY) > (8000.00)) AND ((SALARY) < (10000.00))" + - " THEN SALARY ELSE 0.00 END), C..., " + + "PushedAggregates: [COUNT(CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00)" + + " THEN SALARY ELSE 0.00 END), COUNT(CAS..., " + "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT]" + "PushedGroupByColumns: [DEPT], " checkKeywordsExistsInExplain(df, expected_plan_fragment) } checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 2, 0d), @@ -871,8 +1045,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") checkAggregateRemoved(df, ansiMode) val expected_plan_fragment = if (ansiMode) { - "PushedAggregates: [SUM((2147483647) + (DEPT))], " + - "PushedFilters: [], PushedGroupByColumns: []" + "PushedAggregates: [SUM(2147483647 + DEPT)], " + + "PushedFilters: [], " + + "PushedGroupByColumns: []" } else { "PushedFilters: []" } From 776a2b54152356508eb04bb2ea2bd6e4bb558127 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 23 Mar 2022 09:47:35 +0800 Subject: [PATCH 39/53] [SPARK-38432][SQL][FOLLOWUP] Supplement test case for overflow and add comments ### What changes were proposed in this pull request? This PR follows up https://github.com/apache/spark/pull/35768 and improves the code. 1. Supplement test case for overflow 2. Not throw IllegalArgumentException 3. Improve V2ExpressionSQLBuilder 4. Add comments in V2ExpressionBuilder ### Why are the changes needed? Supplement test case for overflow and add comments. ### Does this PR introduce _any_ user-facing change? 'No'. V2 aggregate pushdown not released yet. ### How was this patch tested? New tests. Closes #35933 from beliefer/SPARK-38432_followup. 
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../util/V2ExpressionSQLBuilder.java | 6 ++-- .../catalyst/util/V2ExpressionBuilder.scala | 2 ++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 5 ++- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 34 ++++++++++++++++--- 4 files changed, 36 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 91dae749f974b..1df01d29cbdd1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -75,7 +75,7 @@ public String build(Expression expr) { name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); case "-": if (e.children().length == 1) { - return visitUnaryArithmetic(name, build(e.children()[0])); + return visitUnaryArithmetic(name, inputToSQL(e.children()[0])); } else { return visitBinaryArithmetic( name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); @@ -87,7 +87,7 @@ public String build(Expression expr) { case "NOT": return visitNot(build(e.children()[0])); case "~": - return visitUnaryArithmetic(name, build(e.children()[0])); + return visitUnaryArithmetic(name, inputToSQL(e.children()[0])); case "CASE_WHEN": { List children = Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList()); @@ -179,7 +179,7 @@ protected String visitNot(String v) { return "NOT (" + v + ")"; } - protected String visitUnaryArithmetic(String name, String v) { return name +" (" + v + ")"; } + protected String visitUnaryArithmetic(String name, String v) { return name + v; } protected String visitCaseWhen(String[] children) { StringBuilder sb = new StringBuilder("CASE"); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index a04e6470f6bf0..392314d473166 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -95,6 +95,7 @@ class V2ExpressionBuilder( None } case and: And => + // AND expects predicate val l = generateExpression(and.left, true) val r = generateExpression(and.right, true) if (l.isDefined && r.isDefined) { @@ -104,6 +105,7 @@ class V2ExpressionBuilder( None } case or: Or => + // OR expects predicate val l = generateExpression(or.left, true) val r = generateExpression(or.right, true) if (l.isDefined && r.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c4b29fd9c9152..f9f90d8fb52b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -229,9 +229,8 @@ abstract class JdbcDialect extends Serializable with Logging{ override def visitNamedReference(namedRef: NamedReference): String = { if (namedRef.fieldNames().length > 1) { - throw new IllegalArgumentException( - QueryCompilationErrors.commandNotSupportNestedColumnError( - "Filter push down", namedRef.toString).getMessage); + throw QueryCompilationErrors.commandNotSupportNestedColumnError( + "Filter push down", namedRef.toString) } 
quoteIdentifier(namedRef.fieldNames.head) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 23f28c2f9c3d1..f42fadf08dcc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -406,14 +406,38 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2))) - val df2 = sql(""" + val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1) + + checkFiltersRemoved(df2, ansiMode) + + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = if (ansiMode) { + "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], " + } else { + "PushedFilters: [ID IS NOT NULL], " + } + checkKeywordsExistsInExplain(df2, expected_plan_fragment) + } + + if (ansiMode) { + val e = intercept[SparkException] { + checkAnswer(df2, Seq.empty) + } + assert(e.getMessage.contains( + "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\"")) + } else { + checkAnswer(df2, Seq.empty) + } + + val df3 = sql(""" |SELECT * FROM h2.test.employee |WHERE (CASE WHEN SALARY > 10000 THEN BONUS ELSE BONUS + 200 END) > 1200 |""".stripMargin) - checkFiltersRemoved(df2, ansiMode) + checkFiltersRemoved(df3, ansiMode) - df2.queryExecution.optimizedPlan.collect { + df3.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = if (ansiMode) { "PushedFilters: [(CASE WHEN SALARY > 10000.00 THEN BONUS" + @@ -421,10 +445,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } else { "PushedFilters: []" } - checkKeywordsExistsInExplain(df2, expected_plan_fragment) + checkKeywordsExistsInExplain(df3, expected_plan_fragment) } - checkAnswer(df2, + checkAnswer(df3, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) } } From 61a6d34b57f7feebb544f222c11708617300393e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 23 Mar 2022 15:22:48 +0800 Subject: [PATCH 40/53] [SPARK-38533][SQL] DS V2 aggregate push-down supports project with alias ### What changes were proposed in this pull request? Currently, Spark DS V2 aggregate push-down doesn't support project with alias. Refer https://github.com/apache/spark/blob/c91c2e9afec0d5d5bbbd2e155057fe409c5bb928/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala#L96 This PR makes it work well with aliases.
**The first example:** the origin plan show below: ``` Aggregate [DEPT#0], [DEPT#0, sum(mySalary#8) AS total#14] +- Project [DEPT#0, SALARY#2 AS mySalary#8] +- ScanBuilderHolder [DEPT#0, NAME#1, SALARY#2, BONUS#3], RelationV2[DEPT#0, NAME#1, SALARY#2, BONUS#3] test.employee, JDBCScanBuilder(org.apache.spark.sql.test.TestSparkSession77978658,StructType(StructField(DEPT,IntegerType,true),StructField(NAME,StringType,true),StructField(SALARY,DecimalType(20,2),true),StructField(BONUS,DoubleType,true)),org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions5f8da82) ``` If we can complete push down the aggregate, then the plan will be: ``` Project [DEPT#0, SUM(SALARY)#18 AS sum(SALARY#2)#13 AS total#14] +- RelationV2[DEPT#0, SUM(SALARY)#18] test.employee ``` If we can partial push down the aggregate, then the plan will be: ``` Aggregate [DEPT#0], [DEPT#0, sum(cast(SUM(SALARY)#18 as decimal(20,2))) AS total#14] +- RelationV2[DEPT#0, SUM(SALARY)#18] test.employee ``` **The second example:** the origin plan show below: ``` Aggregate [myDept#33], [myDept#33, sum(mySalary#34) AS total#40] +- Project [DEPT#25 AS myDept#33, SALARY#27 AS mySalary#34] +- ScanBuilderHolder [DEPT#25, NAME#26, SALARY#27, BONUS#28], RelationV2[DEPT#25, NAME#26, SALARY#27, BONUS#28] test.employee, JDBCScanBuilder(org.apache.spark.sql.test.TestSparkSession25c4f621,StructType(StructField(DEPT,IntegerType,true),StructField(NAME,StringType,true),StructField(SALARY,DecimalType(20,2),true),StructField(BONUS,DoubleType,true)),org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions345d641e) ``` If we can complete push down the aggregate, then the plan will be: ``` Project [DEPT#25 AS myDept#33, SUM(SALARY)#44 AS sum(SALARY#27)#39 AS total#40] +- RelationV2[DEPT#25, SUM(SALARY)#44] test.employee ``` If we can partial push down the aggregate, then the plan will be: ``` Aggregate [myDept#33], [DEPT#25 AS myDept#33, sum(cast(SUM(SALARY)#56 as decimal(20,2))) AS total#52] +- RelationV2[DEPT#25, SUM(SALARY)#56] test.employee ``` ### Why are the changes needed? Alias is more useful. ### Does this PR introduce _any_ user-facing change? 'Yes'. Users could see DS V2 aggregate push-down supports project with alias. ### How was this patch tested? New tests. Closes #35932 from beliefer/SPARK-38533_new. 
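As a usage-level sketch of the new behavior (mirroring the tests added later in this patch rather than introducing new patch content), an aggregate over aliased project output can now be completely pushed down:

```
val df = spark.table("h2.test.employee")
  .select($"DEPT".as("myDept"), $"SALARY".as("mySalary"))
  .groupBy($"myDept")
  .agg(sum($"mySalary").as("total"))
// Expected explain fragment once this change is in place:
// PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]
```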
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../v2/V2ScanRelationPushDown.scala | 22 +++-- .../FileSourceAggregatePushDownSuite.scala | 4 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 86 +++++++++++++++++-- 3 files changed, 97 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 171110cc027ce..a4b5d26699495 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule @@ -34,7 +35,7 @@ import org.apache.spark.sql.sources import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType} import org.apache.spark.sql.util.SchemaUtils._ -object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { @@ -86,22 +87,27 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) - if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) => + if filters.isEmpty && CollapseProject.canCollapseExpressions( + resultExpressions, project, alwaysInline = true) => sHolder.builder match { case r: SupportsPushDownAggregates => + val aliasMap = getAliasMap(project) + val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap)) + val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap)) + val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal) + val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( - groupingExpressions, sHolder.relation.output) + actualGroupExprs, sHolder.relation.output) val 
translatedAggregates = DataSourceStrategy.translateAggregation( normalizedAggregates, normalizedGroupingExpressions) val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { if (translatedAggregates.isEmpty || r.supportCompletePushDown(translatedAggregates.get) || translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { - (resultExpressions, aggregates, translatedAggregates) + (actualResultExprs, aggregates, translatedAggregates) } else { // scalastyle:off // The data source doesn't support the complete push-down of this aggregation. @@ -118,7 +124,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] // +- ScanOperation[...] // scalastyle:on - val newResultExpressions = resultExpressions.map { expr => + val newResultExpressions = actualResultExprs.map { expr => expr.transform { case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) @@ -197,7 +203,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) if (r.supportCompletePushDown(pushedAggregates.get)) { - val projectExpressions = resultExpressions.map { expr => + val projectExpressions = finalResultExpressions.map { expr => // TODO At present, only push down group by attribute is supported. // In future, more attribute conversion is extended here. e.g. GetStructField expr.transform { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index f8cb77757682a..c787493fbdcc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -184,7 +184,7 @@ trait FileSourceAggregatePushDownSuite } } - test("aggregate over alias not push down") { + test("aggregate over alias push down") { val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), (9, "mno", 7), (2, null, 6)) withDataSourceTable(data, "t") { @@ -194,7 +194,7 @@ trait FileSourceAggregatePushDownSuite query.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: []" // aggregate alias not pushed down + "PushedAggregation: [MIN(_1)]" checkKeywordsExistsInExplain(query, expected_plan_fragment) } checkAnswer(query, Seq(Row(-2))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index f42fadf08dcc7..60b19ec4a6081 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -995,15 +995,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) } - test("scan with aggregate push-down: aggregate over alias NOT push down") { + test("scan with aggregate push-down: aggregate over alias push down") { val cols = Seq("a", "b", "c", "d", "e") val df1 = sql("select * from h2.test.employee").toDF(cols: _*) val df2 = df1.groupBy().sum("c") - checkAggregateRemoved(df2, false) + 
checkAggregateRemoved(df2) df2.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.aggregation.isEmpty) + case relation: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []" + checkKeywordsExistsInExplain(df2, expectedPlanFragment) + relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.aggregation.nonEmpty) } } checkAnswer(df2, Seq(Row(53000.00))) @@ -1231,4 +1235,76 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } checkAnswer(query, Seq(Row(29000.0))) } + + test("scan with aggregate push-down: complete push-down aggregate with alias") { + val df = spark.table("h2.test.employee") + .select($"DEPT", $"SALARY".as("mySalary")) + .groupBy($"DEPT") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df, expectedPlanFragment) + } + checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) + + val df2 = spark.table("h2.test.employee") + .select($"DEPT".as("myDept"), $"SALARY".as("mySalary")) + .groupBy($"myDept") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df2) + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" + checkKeywordsExistsInExplain(df2, expectedPlanFragment) + } + checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) + } + + test("scan with aggregate push-down: partial push-down aggregate with alias") { + val df = spark.read + .option("partitionColumn", "DEPT") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .select($"NAME", $"SALARY".as("mySalary")) + .groupBy($"NAME") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" + checkKeywordsExistsInExplain(df, expectedPlanFragment) + } + checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00), + Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) + + val df2 = spark.read + .option("partitionColumn", "DEPT") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .select($"NAME".as("myName"), $"SALARY".as("mySalary")) + .groupBy($"myName") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df2, false) + df2.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" + checkKeywordsExistsInExplain(df2, expectedPlanFragment) + } + checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), + Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) + } } From 219eb4f99a98bda4d81bf303e7e9789a29f302e6 Mon Sep 17 
00:00:00 2001 From: chenzhx Date: Wed, 23 Mar 2022 23:30:35 +0800 Subject: [PATCH 41/53] code format --- .../main/scala/org/apache/spark/sql/sources/filters.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index c4d9be95e97ba..e358ff0cb6677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -376,4 +376,9 @@ object AlwaysFalse extends AlwaysFalse { @Evolving case class Trivial(value: Boolean) extends Filter { override def references: Array[String] = findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate("TRIVIAL", + Array(LiteralValue(literal.value, literal.dataType))) + } } From cab02669d1d7c6ac07f812de882bbf5c4a8b5908 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 23 Mar 2022 21:40:44 +0800 Subject: [PATCH 42/53] [SPARK-37483][SQL][FOLLOWUP] Rename `pushedTopN` to `PushedTopN` and improve JDBCV2Suite ### What changes were proposed in this pull request? This PR fixes three issues. **First**, create methods `checkPushedInfo` and `checkSortRemoved` to reuse code. **Second**, remove method `checkPushedLimit`, because `checkPushedInfo` can cover it. **Third**, rename `pushedTopN` to `PushedTopN`, so it is consistent with the other pushed information. ### Why are the changes needed? Reuse code and make the pushed information more accurate. ### Does this PR introduce _any_ user-facing change? 'No'. It is a new feature and only improves the tests. ### How was this patch tested? Adjust existing tests. Closes #35921 from beliefer/SPARK-37483_followup. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../sql/execution/DataSourceScanExec.scala | 2 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 595 +++++------------- 2 files changed, 169 insertions(+), 428 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index bb5b8e32aef63..432775c9045ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -147,7 +147,7 @@ case class RowDataSourceScanExec( val pushedTopN = s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" + s" LIMIT ${pushedDownOperators.limit.get}" - Some("pushedTopN" -> pushedTopN) + Some("PushedTopN" -> pushedTopN) } else { pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 60b19ec4a6081..543c52e2704e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -24,7 +24,6 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} -import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import
org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, not, sum, udf, when} @@ -110,23 +109,35 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2))) } + private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String): Unit = { + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + checkKeywordsExistsInExplain(df, expectedPlanFragment) + } + } + // TABLESAMPLE ({integer_expression | decimal_expression} PERCENT) and // TABLESAMPLE (BUCKET integer_expression OUT OF integer_expression) // are tested in JDBC dialect tests because TABLESAMPLE is not supported by all the DBMS test("TABLESAMPLE (integer_expression ROWS) is the same as LIMIT") { val df = sql("SELECT NAME FROM h2.test.employee TABLESAMPLE (3 ROWS)") + checkSchemaNames(df, Seq("NAME")) + checkPushedInfo(df, "PushedFilters: [], PushedLimit: LIMIT 3, ") + checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy"))) + } + + private def checkSchemaNames(df: DataFrame, names: Seq[String]): Unit = { val scan = df.queryExecution.optimizedPlan.collectFirst { case s: DataSourceV2ScanRelation => s }.get - assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedLimit(df, Some(3)) - checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy"))) + assert(scan.schema.names.sameElements(names)) } test("simple scan with LIMIT") { val df1 = spark.read.table("h2.test.employee") .where($"dept" === 1).limit(1) - checkPushedLimit(df1, Some(1)) + checkPushedInfo(df1, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 1, ") checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0, true))) val df2 = spark.read @@ -137,22 +148,22 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .filter($"dept" > 1) .limit(1) - checkPushedLimit(df2, Some(1)) + checkPushedInfo(df2, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ") checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0, false))) val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1") - val scan = df3.queryExecution.optimizedPlan.collectFirst { - case s: DataSourceV2ScanRelation => s - }.get - assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedLimit(df3, Some(1)) + checkSchemaNames(df3, Seq("NAME")) + checkPushedInfo(df3, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ") checkAnswer(df3, Seq(Row("alex"))) val df4 = spark.read .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .limit(1) - checkPushedLimit(df4, None) + checkPushedInfo(df4, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT], ") checkAnswer(df4, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -163,24 +174,18 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter(name($"shortName")) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkPushedLimit(df5, None) + checkPushedInfo(df5, "PushedFilters: [], ") checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy"))) } - private def checkPushedLimit(df: DataFrame, limit: Option[Int] = None, - sortValues: Seq[SortValue] = Nil): Unit = { - df.queryExecution.optimizedPlan.collect { - case 
relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.limit === limit) - assert(v1.pushedDownOperators.sortValues === sortValues) - } + private def checkSortRemoved(df: DataFrame, removed: Boolean = true): Unit = { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s } - if (sortValues.nonEmpty) { - val sorts = df.queryExecution.optimizedPlan.collect { - case s: Sort => s - } + if (removed) { assert(sorts.isEmpty) + } else { + assert(sorts.nonEmpty) } } @@ -189,12 +194,16 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .sort("salary") .limit(1) - checkPushedLimit(df1, Some(1), createSortValues()) + checkSortRemoved(df1) + checkPushedInfo(df1, + "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ") checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df2 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary").limit(1) - checkPushedLimit(df2, Some(1), createSortValues()) + checkSortRemoved(df2) + checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " + + "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ") checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) val df3 = spark.read @@ -206,22 +215,23 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"dept" > 1) .orderBy($"salary".desc) .limit(1) - checkPushedLimit( - df3, Some(1), createSortValues(SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) + checkSortRemoved(df3) + checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + + "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ") checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false))) val df4 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1") - val scan = df4.queryExecution.optimizedPlan.collectFirst { - case s: DataSourceV2ScanRelation => s - }.get - assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedLimit(df4, Some(1), createSortValues(nullOrdering = NullOrdering.NULLS_LAST)) + checkSchemaNames(df4, Seq("NAME")) + checkSortRemoved(df4) + checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + + "PushedTopN: ORDER BY [salary ASC NULLS LAST] LIMIT 1, ") checkAnswer(df4, Seq(Row("david"))) val df5 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary") - checkPushedLimit(df5, None) + checkSortRemoved(df5, false) + checkPushedInfo(df5, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ") checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0, false), Row(1, "amy", 10000.00, 1000.0, true))) @@ -230,7 +240,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .groupBy("DEPT").sum("SALARY") .orderBy("DEPT") .limit(1) - checkPushedLimit(df6) + checkSortRemoved(df6, false) + checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," + + " PushedFilters: [], PushedGroupByColumns: [DEPT], ") checkAnswer(df6, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -242,147 +254,69 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .sort($"SALARY".desc) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkPushedLimit(df7) + checkSortRemoved(df7, false) + checkPushedInfo(df7, "PushedFilters: [], ") checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy"))) val df8 = spark.read 
.table("h2.test.employee") .sort(sub($"NAME")) .limit(1) - checkPushedLimit(df8) + checkSortRemoved(df8, false) + checkPushedInfo(df8, "PushedFilters: [], ") checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false))) } - private def createSortValues( - sortDirection: SortDirection = SortDirection.ASCENDING, - nullOrdering: NullOrdering = NullOrdering.NULLS_FIRST): Seq[SortValue] = { - Seq(SortValue(FieldReference("salary"), sortDirection, nullOrdering)) - } - test("scan with filter push-down") { val df = spark.table("h2.test.people").filter($"id" > 1) - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters.isEmpty) - - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [ID IS NOT NULL, ID > 1]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } - + checkFiltersRemoved(df) + checkPushedInfo(df, "PushedFilters: [ID IS NOT NULL, ID > 1], ") checkAnswer(df, Row("mary", 2)) val df2 = spark.table("h2.test.employee").filter($"name".isin("amy", "cathy")) - checkFiltersRemoved(df2) - - df2.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [NAME IN ('amy', 'cathy')]" - checkKeywordsExistsInExplain(df2, expected_plan_fragment) - } - + checkPushedInfo(df2, "PushedFilters: [NAME IN ('amy', 'cathy')]") checkAnswer(df2, Seq(Row(1, "amy", 10000, 1000, true), Row(1, "cathy", 9000, 1200, false))) val df3 = spark.table("h2.test.employee").filter($"name".startsWith("a")) - checkFiltersRemoved(df3) - - df3.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%']" - checkKeywordsExistsInExplain(df3, expected_plan_fragment) - } - + checkPushedInfo(df3, "PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%']") checkAnswer(df3, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false))) val df4 = spark.table("h2.test.employee").filter($"is_manager") - checkFiltersRemoved(df4) - - df4.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true]" - checkKeywordsExistsInExplain(df4, expected_plan_fragment) - } - + checkPushedInfo(df4, "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true]") checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) val df5 = spark.table("h2.test.employee").filter($"is_manager".and($"salary" > 10000)) - checkFiltersRemoved(df5) - - df5.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [IS_MANAGER IS NOT NULL, SALARY IS NOT NULL, " + - "IS_MANAGER = true, SALARY > 10000.00]" - checkKeywordsExistsInExplain(df5, expected_plan_fragment) - } - + checkPushedInfo(df5, "PushedFilters: [IS_MANAGER IS NOT NULL, SALARY IS NOT NULL, " + + "IS_MANAGER = true, SALARY > 10000.00]") checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true))) val df6 = spark.table("h2.test.employee").filter($"is_manager".or($"salary" > 10000)) - checkFiltersRemoved(df6) - - df6.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [(IS_MANAGER = true) OR (SALARY > 10000.00)], " - checkKeywordsExistsInExplain(df6, expected_plan_fragment) - } - + checkPushedInfo(df6, 
"PushedFilters: [(IS_MANAGER = true) OR (SALARY > 10000.00)], ") checkAnswer(df6, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false), Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) val df7 = spark.table("h2.test.employee").filter(not($"is_manager") === true) - checkFiltersRemoved(df7) - - df7.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [IS_MANAGER IS NOT NULL, NOT (IS_MANAGER = true)], " - checkKeywordsExistsInExplain(df7, expected_plan_fragment) - } - + checkPushedInfo(df7, "PushedFilters: [IS_MANAGER IS NOT NULL, NOT (IS_MANAGER = true)], ") checkAnswer(df7, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "alex", 12000, 1200, false))) val df8 = spark.table("h2.test.employee").filter($"is_manager" === true) - checkFiltersRemoved(df8) - - df8.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true], " - checkKeywordsExistsInExplain(df8, expected_plan_fragment) - } - + checkPushedInfo(df8, "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true], ") checkAnswer(df8, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) val df9 = spark.table("h2.test.employee") .filter(when($"dept" > 1, true).when($"is_manager", false).otherwise($"dept" > 3)) - checkFiltersRemoved(df9) - - df9.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [CASE WHEN DEPT > 1 THEN TRUE WHEN IS_MANAGER = true THEN FALSE" + - " ELSE DEPT > 3 END], " - checkKeywordsExistsInExplain(df9, expected_plan_fragment) - } - + checkPushedInfo(df9, "PushedFilters: [CASE WHEN DEPT > 1 THEN TRUE " + + "WHEN IS_MANAGER = true THEN FALSE ELSE DEPT > 3 END], ") checkAnswer(df9, Seq(Row(2, "alex", 12000, 1200, false), Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) } @@ -391,19 +325,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Seq(false, true).foreach { ansiMode => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val df = spark.table("h2.test.people").filter($"id" + 1 > 1) - checkFiltersRemoved(df, ansiMode) - - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = if (ansiMode) { - "PushedFilters: [ID IS NOT NULL, (ID + 1) > 1]" - } else { - "PushedFilters: [ID IS NOT NULL]" - } - checkKeywordsExistsInExplain(df, expected_plan_fragment) + val expectedPlanFragment = if (ansiMode) { + "PushedFilters: [ID IS NOT NULL, (ID + 1) > 1]" + } else { + "PushedFilters: [ID IS NOT NULL]" } - + checkPushedInfo(df, expectedPlanFragment) checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2))) val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1) @@ -436,18 +364,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel |""".stripMargin) checkFiltersRemoved(df3, ansiMode) - - df3.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = if (ansiMode) { - "PushedFilters: [(CASE WHEN SALARY > 10000.00 THEN BONUS" + - " ELSE BONUS + 200.0 END) > 1200.0]" - } else { - "PushedFilters: []" - } - checkKeywordsExistsInExplain(df3, expected_plan_fragment) + val expectedPlanFragment3 = if (ansiMode) { + "PushedFilters: [(CASE WHEN SALARY > 10000.00 THEN BONUS" + + " ELSE 
BONUS + 200.0 END) > 1200.0]" + } else { + "PushedFilters: []" } - + checkPushedInfo(df3, expectedPlanFragment3) checkAnswer(df3, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) } @@ -456,23 +379,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with column pruning") { val df = spark.table("h2.test.people").select("id") - val scan = df.queryExecution.optimizedPlan.collectFirst { - case s: DataSourceV2ScanRelation => s - }.get - assert(scan.schema.names.sameElements(Seq("ID"))) + checkSchemaNames(df, Seq("ID")) checkAnswer(df, Seq(Row(1), Row(2))) } test("scan with filter push-down and column pruning") { val df = spark.table("h2.test.people").filter($"id" > 1).select("name") - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters.isEmpty) - val scan = df.queryExecution.optimizedPlan.collectFirst { - case s: DataSourceV2ScanRelation => s - }.get - assert(scan.schema.names.sameElements(Seq("NAME"))) + checkFiltersRemoved(df) + checkSchemaNames(df, Seq("NAME")) checkAnswer(df, Row("mary")) } @@ -597,19 +511,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: MAX AVG with filter and group by") { val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" + " group by DePt") - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters.isEmpty) + checkFiltersRemoved(df) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT], ") checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) } @@ -626,19 +532,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: MAX AVG with filter without group by") { val df = sql("select MAX(ID), AVG(ID) FROM h2.test.people where id > 0") - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters.isEmpty) + checkFiltersRemoved(df) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MAX(ID), AVG(ID)], " + - "PushedFilters: [ID IS NOT NULL, ID > 0], " + - "PushedGroupByColumns: [], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [MAX(ID), AVG(ID)], " + + "PushedFilters: [ID IS NOT NULL, ID > 0], " + + "PushedGroupByColumns: [], ") checkAnswer(df, Seq(Row(2, 1.5))) } @@ -668,42 +566,28 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "PushedAggregates: [MAX(SALARY)]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } + checkPushedInfo(df, "PushedAggregates: [MAX(SALARY)]") checkAnswer(df, Seq(Row(12001))) } test("scan with aggregate push-down: COUNT(*)") { val df = sql("select COUNT(*) FROM h2.test.employee") checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - 
"PushedAggregates: [COUNT(*)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [COUNT(*)]") checkAnswer(df, Seq(Row(5))) } test("scan with aggregate push-down: COUNT(col)") { val df = sql("select COUNT(DEPT) FROM h2.test.employee") checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [COUNT(DEPT)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [COUNT(DEPT)]") checkAnswer(df, Seq(Row(5))) } test("scan with aggregate push-down: COUNT(DISTINCT col)") { val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee") checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [COUNT(DISTINCT DEPT)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [COUNT(DISTINCT DEPT)]") checkAnswer(df, Seq(Row(3))) } @@ -722,71 +606,40 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: SUM without filer and group by") { val df = sql("SELECT SUM(SALARY) FROM h2.test.employee") checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY)]") checkAnswer(df, Seq(Row(53000))) } test("scan with aggregate push-down: DISTINCT SUM without filer and group by") { val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee") checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(DISTINCT SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)]") checkAnswer(df, Seq(Row(31000))) } test("scan with aggregate push-down: SUM with group by") { val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY)], " + - "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY)], " + + "PushedFilters: [], PushedGroupByColumns: [DEPT], ") checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } test("scan with aggregate push-down: DISTINCT SUM with group by") { val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT") checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(DISTINCT SALARY)], " + - "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)], " + + "PushedFilters: [], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } test("scan with aggregate push-down: with multiple group by columns") { val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group 
by DEPT, NAME") - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters.isEmpty) + checkFiltersRemoved(df) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT, NAME], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), Row(10000, 1000), Row(12000, 1200))) } @@ -799,14 +652,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } assert(filters1.isEmpty) checkAggregateRemoved(df1) - df1.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MAX(SALARY)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT, NAME], " - checkKeywordsExistsInExplain(df1, expected_plan_fragment) - } + checkPushedInfo(df1, "PushedAggregates: [MAX(SALARY)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000), Row("2#david", 10000), Row("6#jen", 12000))) @@ -817,30 +664,16 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } assert(filters2.isEmpty) checkAggregateRemoved(df2) - df2.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT, NAME], " - checkKeywordsExistsInExplain(df2, expected_plan_fragment) - } + checkPushedInfo(df2, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), Row("2#david", 11300), Row("6#jen", 13200))) val df3 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" + " FROM h2.test.employee where dept > 0 group by concat_ws('#', DEPT, NAME)") - val filters3 = df3.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters3.isEmpty) + checkFiltersRemoved(df3) checkAggregateRemoved(df3, false) - df3.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " - checkKeywordsExistsInExplain(df3, expected_plan_fragment) - } + checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], ") checkAnswer(df3, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), Row("2#david", 11300), Row("6#jen", 13200))) } @@ -848,19 +681,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("scan with aggregate push-down: with having clause") { val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DEPT having MIN(BONUS) > 1000") - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f // filter over aggregate not push down - } - assert(filters.nonEmpty) + // filter over aggregate not push down + checkFiltersRemoved(df, 
false) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200))) } @@ -869,14 +694,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .groupBy($"DEPT") .min("SALARY").as("total") checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MIN(SALARY)], " + - "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [MIN(SALARY)], " + + "PushedFilters: [], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000))) } @@ -888,19 +707,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .agg(sum($"SALARY").as("total")) .filter($"total" > 1000) .orderBy($"total") - val filters = query.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters.nonEmpty) // filter over aggregate not pushed down - checkAggregateRemoved(df) - query.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(query, expected_plan_fragment) - } + checkFiltersRemoved(query, false)// filter over aggregate not pushed down + checkAggregateRemoved(query) + checkPushedInfo(query, "PushedAggregates: [SUM(SALARY)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000))) } @@ -909,12 +719,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val decrease = udf { (x: Double, y: Double) => x - y } val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value")) checkAggregateRemoved(query) - query.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY), SUM(BONUS)]" - checkKeywordsExistsInExplain(query, expected_plan_fragment) - } + checkPushedInfo(query, "PushedAggregates: [SUM(SALARY), SUM(BONUS)], ") checkAnswer(query, Seq(Row(47100.0))) } @@ -936,14 +741,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel " group by DePt") checkFiltersRemoved(df) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -952,14 +751,8 @@ class JDBCV2Suite 
extends QueryTest with SharedSparkSession with ExplainSuiteHel " where dept > 0 group by DePt") checkFiltersRemoved(df) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) } @@ -968,14 +761,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel " FROM h2.test.employee where dept > 0 group by DePt") checkFiltersRemoved(df) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) } @@ -984,14 +771,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel " group by DePt") checkFiltersRemoved(df) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [CORR(BONUS, BONUS)], " + - "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [CORR(BONUS, BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) } @@ -1053,15 +834,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel |FROM h2.test.employee GROUP BY DEPT """.stripMargin) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [COUNT(CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00)" + - " THEN SALARY ELSE 0.00 END), COUNT(CAS..., " + - "PushedFilters: [], " + - "PushedGroupByColumns: [DEPT], " - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, + "PushedAggregates: [COUNT(CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00)" + + " THEN SALARY ELSE 0.00 END), COUNT(CAS..., " + + "PushedFilters: [], " + + "PushedGroupByColumns: [DEPT], ") checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 2, 0d), Row(2, 2, 2, 2, 2, 0d, 10000d, 0d, 10000d, 10000d, 0d, 0d, 2, 0d), Row(2, 2, 2, 2, 2, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 3, 0d))) @@ -1072,17 +849,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") checkAggregateRemoved(df, ansiMode) - val expected_plan_fragment = if (ansiMode) { + val 
expectedPlanFragment = if (ansiMode) { "PushedAggregates: [SUM(2147483647 + DEPT)], " + "PushedFilters: [], " + "PushedGroupByColumns: []" } else { "PushedFilters: []" } - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, expectedPlanFragment) if (ansiMode) { val e = intercept[SparkException] { checkAnswer(df, Seq(Row(-10737418233L))) @@ -1101,12 +875,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val decrease = udf { (x: Double, y: Double) => x - y } val query = df.select(sum(decrease($"SALARY", $"BONUS")).as("value")) checkAggregateRemoved(query, false) - query.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: []" - checkKeywordsExistsInExplain(query, expected_plan_fragment) - } + checkPushedInfo(query, "PushedFilters: []") checkAnswer(query, Seq(Row(47100.0))) } @@ -1138,6 +907,24 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1))) } + test("column name with composite field") { + checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2))) + val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept") + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [COUNT(`dept id`)]") + checkAnswer(df, Seq(Row(2))) + } + + test("column name with non-ascii") { + // scalastyle:off + checkAnswer(sql("SELECT `名` FROM h2.test.person"), Seq(Row(1), Row(2))) + val df = sql("SELECT COUNT(`名`) FROM h2.test.person") + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [COUNT(`名`)]") + checkAnswer(df, Seq(Row(2))) + // scalastyle:on + } + test("scan with aggregate push-down: complete push-down SUM, AVG, COUNT") { val df = spark.read .option("partitionColumn", "dept") @@ -1147,12 +934,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]") checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) val df2 = spark.read @@ -1164,12 +946,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .groupBy($"name") .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]") checkAnswer(df2, Seq( Row("alex", 12000.00, 12000.000000, 1), Row("amy", 10000.00, 10000.000000, 1), @@ -1187,12 +964,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) checkAggregateRemoved(df, false) - 
df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]") checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) val df2 = spark.read @@ -1204,12 +976,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .groupBy($"name") .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) checkAggregateRemoved(df, false) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]") checkAnswer(df2, Seq( Row("alex", 12000.00, 12000.000000, 1), Row("amy", 10000.00, 10000.000000, 1), @@ -1218,22 +985,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Row("jen", 12000.00, 12000.000000, 1))) } - test("scan with aggregate push-down: aggregate with partially pushed down filters" + - "will NOT push down") { - val df = spark.table("h2.test.employee") - val name = udf { (x: String) => x.matches("cat|dav|amy") } - val sub = udf { (x: String) => x.substring(0, 3) } - val query = df.select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) - .filter("SALARY > 100") - .filter(name($"shortName")) - .agg(sum($"SALARY").as("sum_salary")) - query.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: []" - checkKeywordsExistsInExplain(query, expected_plan_fragment) - } - checkAnswer(query, Seq(Row(29000.0))) + test("SPARK-37895: JDBC push down with delimited special identifiers") { + val df = sql( + """SELECT h2.test.view1.`|col1`, h2.test.view1.`|col2`, h2.test.view2.`|col3` + |FROM h2.test.view1 LEFT JOIN h2.test.view2 + |ON h2.test.view1.`|col1` = h2.test.view2.`|col1`""".stripMargin) + checkAnswer(df, Seq.empty[Row]) } test("scan with aggregate push-down: complete push-down aggregate with alias") { @@ -1243,12 +1000,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .agg(sum($"mySalary").as("total")) .filter($"total" > 1000) checkAggregateRemoved(df) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expectedPlanFragment = - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df, expectedPlanFragment) - } + checkPushedInfo(df, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) val df2 = spark.table("h2.test.employee") @@ -1257,12 +1010,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .agg(sum($"mySalary").as("total")) .filter($"total" > 1000) checkAggregateRemoved(df2) - df2.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expectedPlanFragment = - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]" - checkKeywordsExistsInExplain(df2, expectedPlanFragment) - } + checkPushedInfo(df2, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]") checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) 
} @@ -1278,12 +1027,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .agg(sum($"mySalary").as("total")) .filter($"total" > 1000) checkAggregateRemoved(df, false) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expectedPlanFragment = - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" - checkKeywordsExistsInExplain(df, expectedPlanFragment) - } + checkPushedInfo(df, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]") checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) @@ -1298,12 +1043,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .agg(sum($"mySalary").as("total")) .filter($"total" > 1000) checkAggregateRemoved(df2, false) - df2.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expectedPlanFragment = - "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]" - checkKeywordsExistsInExplain(df2, expectedPlanFragment) - } + checkPushedInfo(df2, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]") checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) } From 9f3194c6ae7a6f2cc61cb4ef22220c8992deb354 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 25 Mar 2022 20:00:39 +0800 Subject: [PATCH 43/53] [SPARK-38644][SQL] DS V2 topN push-down supports project with alias ### What changes were proposed in this pull request? Currently, Spark DS V2 topN push-down doesn't supports project with alias. This PR let it works good with alias. **Example**: the origin plan show below: ``` Sort [mySalary#10 ASC NULLS FIRST], true +- Project [NAME#1, SALARY#2 AS mySalary#10] +- ScanBuilderHolder [DEPT#0, NAME#1, SALARY#2, BONUS#3, IS_MANAGER#4], RelationV2[DEPT#0, NAME#1, SALARY#2, BONUS#3, IS_MANAGER#4] test.employee, JDBCScanBuilder(org.apache.spark.sql.test.TestSparkSession7fd4b9ec,StructType(StructField(DEPT,IntegerType,true),StructField(NAME,StringType,true),StructField(SALARY,DecimalType(20,2),true),StructField(BONUS,DoubleType,true),StructField(IS_MANAGER,BooleanType,true)),org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions3c8e4a82) ``` The `pushedLimit` and `sortOrders` of `JDBCScanBuilder` are empty. If we can push down the top n, then the plan will be: ``` Project [NAME#1, SALARY#2 AS mySalary#10] +- ScanBuilderHolder [DEPT#0, NAME#1, SALARY#2, BONUS#3, IS_MANAGER#4], RelationV2[DEPT#0, NAME#1, SALARY#2, BONUS#3, IS_MANAGER#4] test.employee, JDBCScanBuilder(org.apache.spark.sql.test.TestSparkSession7fd4b9ec,StructType(StructField(DEPT,IntegerType,true),StructField(NAME,StringType,true),StructField(SALARY,DecimalType(20,2),true),StructField(BONUS,DoubleType,true),StructField(IS_MANAGER,BooleanType,true)),org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions3c8e4a82) ``` The `pushedLimit` of `JDBCScanBuilder` will be `1` and `sortOrders` of `JDBCScanBuilder` will be `SALARY ASC NULLS FIRST`. ### Why are the changes needed? Alias is more useful. ### Does this PR introduce _any_ user-facing change? 'Yes'. Users could see DS V2 topN push-down supports project with alias. ### How was this patch tested? New tests. Closes #35961 from beliefer/SPARK-38644. 
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../v2/V2ScanRelationPushDown.scala | 15 +++++++----- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 24 +++++++++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index a4b5d26699495..116a1f95fd93e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} @@ -365,9 +365,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit sHolder.pushedLimit = Some(limit) } operation - case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder)) - if filter.isEmpty => - val orders = DataSourceStrategy.translateSortOrders(order) + case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder)) + if filter.isEmpty && CollapseProject.canCollapseExpressions( + order, project, alwaysInline = true) => + val aliasMap = getAliasMap(project) + val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]] + val orders = DataSourceStrategy.translateSortOrders(newOrder) if (orders.length == order.length) { val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) if (topNPushed) { @@ -418,7 +421,7 @@ case class ScanBuilderHolder( builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None - var sortOrders: Seq[SortOrder] = Seq.empty[SortOrder] + var sortOrders: Seq[V2SortOrder] = Seq.empty[V2SortOrder] var pushedSample: Option[TableSampleInfo] = None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala 
index 543c52e2704e7..3309186cc3e5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -267,6 +267,30 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false))) } + test("simple scan with top N: order by with alias") { + val df1 = spark.read + .table("h2.test.employee") + .select($"NAME", $"SALARY".as("mySalary")) + .sort("mySalary") + .limit(1) + checkSortRemoved(df1) + checkPushedInfo(df1, + "PushedFilters: [], PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ") + checkAnswer(df1, Seq(Row("cathy", 9000.00))) + + val df2 = spark.read + .table("h2.test.employee") + .select($"DEPT", $"NAME", $"SALARY".as("mySalary")) + .filter($"DEPT" > 1) + .sort("mySalary") + .limit(1) + checkSortRemoved(df2) + checkPushedInfo(df2, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ") + checkAnswer(df2, Seq(Row(2, "david", 10000.00))) + } + test("scan with filter push-down") { val df = spark.table("h2.test.people").filter($"id" > 1) checkFiltersRemoved(df) From dbb8c2db14c6c68b911ee32825797fc3ce1cebd9 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 28 Mar 2022 13:25:46 +0800 Subject: [PATCH 44/53] [SPARK-38391][SQL] Datasource v2 supports partial topN push-down ### What changes were proposed in this pull request? Currently , Spark supports push down topN completely . But for some data source (e.g. JDBC ) that have multiple partition , we should preserve partial push down topN. ### Why are the changes needed? Make behavior of sort pushdown correctly. ### Does this PR introduce _any_ user-facing change? 'No'. Just change the inner implement. ### How was this patch tested? New tests. Closes #35710 from beliefer/SPARK-38391. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../sql/connector/read/SupportsPushDownTopN.java | 6 ++++++ .../execution/datasources/v2/PushDownUtils.scala | 11 +++++++---- .../datasources/v2/V2ScanRelationPushDown.scala | 11 ++++++++--- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 2 ++ .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 13 ++++++++++--- 5 files changed, 33 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java index 0212895fde079..cba1592c4fa14 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -35,4 +35,10 @@ public interface SupportsPushDownTopN extends ScanBuilder { * Pushes down top N to the data source. */ boolean pushTopN(SortOrder[] orders, int limit); + + /** + * Whether the top N is partially pushed or not. If it returns true, then Spark will do top N + * again. This method will only be called when {@link #pushTopN} returns true. 
+ */ + default boolean isPartiallyPushed() { return true; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 2f55b7ee46ac7..2adbd5cf007e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -130,11 +130,14 @@ object PushDownUtils extends PredicateHelper { /** * Pushes down top N to the data source Scan */ - def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = { + def pushTopN( + scanBuilder: ScanBuilder, + order: Array[SortOrder], + limit: Int): (Boolean, Boolean) = { scanBuilder match { - case s: SupportsPushDownTopN => - s.pushTopN(order, limit) - case _ => false + case s: SupportsPushDownTopN if s.pushTopN(order, limit) => + (true, s.isPartiallyPushed) + case _ => (false, false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 116a1f95fd93e..24e3a6c91b13d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -372,11 +372,16 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]] val orders = DataSourceStrategy.translateSortOrders(newOrder) if (orders.length == order.length) { - val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) - if (topNPushed) { + val (isPushed, isPartiallyPushed) = + PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) + if (isPushed) { sHolder.pushedLimit = Some(limit) sHolder.sortOrders = orders - operation + if (isPartiallyPushed) { + s + } else { + operation + } } else { s } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 475f563856f82..0a1542a42956d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -146,6 +146,8 @@ case class JDBCScanBuilder( false } + override def isPartiallyPushed(): Boolean = jdbcOptions.numPartitions.map(_ > 1).getOrElse(false) + override def pruneColumns(requiredSchema: StructType): Unit = { // JDBC doesn't support nested column pruning. // TODO (SPARK-32593): JDBC support nested column and nested column pruning. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 3309186cc3e5b..8a2409599c584 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -199,8 +199,15 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ") checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) - val df2 = spark.read.table("h2.test.employee") - .where($"dept" === 1).orderBy($"salary").limit(1) + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + .where($"dept" === 1) + .orderBy($"salary") + .limit(1) checkSortRemoved(df2) checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " + "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ") @@ -215,7 +222,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"dept" > 1) .orderBy($"salary".desc) .limit(1) - checkSortRemoved(df3) + checkSortRemoved(df3, false) checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ") checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false))) From b67333df6b45fd50fa0f5273825b2822c84de677 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 29 Mar 2022 11:25:58 +0800 Subject: [PATCH 45/53] [SPARK-38633][SQL] Support push down Cast to JDBC data source V2 ### What changes were proposed in this pull request? Cast is very useful and Spark always use Cast to convert data type automatically. ### Why are the changes needed? Let more aggregates and filters could be pushed down. ### Does this PR introduce _any_ user-facing change? 'Yes'. This PR after cut off 3.3.0. ### How was this patch tested? New tests. Closes #35947 from beliefer/SPARK-38633. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../spark/sql/connector/expressions/Cast.java | 45 +++++++++++++++++++ .../util/V2ExpressionSQLBuilder.java | 9 ++++ .../catalyst/util/V2ExpressionBuilder.scala | 6 ++- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 21 ++++++++- 4 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java new file mode 100644 index 0000000000000..26b97b46fe2ef --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import java.io.Serializable; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.types.DataType; + +/** + * Represents a cast expression in the public logical expression API. + * + * @since 3.3.0 + */ +@Evolving +public class Cast implements Expression, Serializable { + private Expression expression; + private DataType dataType; + + public Cast(Expression expression, DataType dataType) { + this.expression = expression; + this.dataType = dataType; + } + + public Expression expression() { return expression; } + public DataType dataType() { return dataType; } + + @Override + public Expression[] children() { return new Expression[]{ expression() }; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 1df01d29cbdd1..c8d924db75aed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -21,10 +21,12 @@ import java.util.List; import java.util.stream.Collectors; +import org.apache.spark.sql.connector.expressions.Cast; import org.apache.spark.sql.connector.expressions.Expression; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.types.DataType; /** * The builder to generate SQL from V2 expressions. 
@@ -36,6 +38,9 @@ public String build(Expression expr) { return visitLiteral((Literal) expr); } else if (expr instanceof NamedReference) { return visitNamedReference((NamedReference) expr); + } else if (expr instanceof Cast) { + Cast cast = (Cast) expr; + return visitCast(build(cast.expression()), cast.dataType()); } else if (expr instanceof GeneralScalarExpression) { GeneralScalarExpression e = (GeneralScalarExpression) expr; String name = e.name(); @@ -167,6 +172,10 @@ protected String visitBinaryArithmetic(String name, String l, String r) { return l + " " + name + " " + r; } + protected String visitCast(String l, DataType dataType) { + return "CAST(" + l + " AS " + dataType.typeName() + ")"; + } + protected String visitAnd(String name, String l, String r) { return "(" + l + ") " + name + " (" + r + ")"; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 392314d473166..f12305d971209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} -import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} +import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} +import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableColumn import org.apache.spark.sql.types.BooleanType @@ -94,6 +94,8 @@ class V2ExpressionBuilder( } else { None } + case Cast(child, dataType, _, true) => + generateExpression(child).map(v => new V2Cast(v, dataType)) case and: And => // AND expects predicate val l = generateExpression(and.left, true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 8a2409599c584..28f9c533cd1d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -352,7 +352,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) } - test("scan with complex filter push-down") { + test("scan with filter push-down with ansi mode") { Seq(false, true).foreach { ansiMode => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val df = spark.table("h2.test.people").filter($"id" + 1 > 1) @@ -404,6 +404,25 @@ 
class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkPushedInfo(df3, expectedPlanFragment3) checkAnswer(df3, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) + + val df4 = spark.table("h2.test.employee") + .filter(($"salary" > 1000d).and($"salary" < 12000d)) + + checkFiltersRemoved(df4, ansiMode) + + df4.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = if (ansiMode) { + "PushedFilters: [SALARY IS NOT NULL, " + + "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], " + } else { + "PushedFilters: [SALARY IS NOT NULL], " + } + checkKeywordsExistsInExplain(df4, expected_plan_fragment) + } + + checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true), + Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) } } } From e6cfc550a4400a0771c5476a57266e5926dc1059 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 28 Mar 2022 22:06:05 +0800 Subject: [PATCH 46/53] [SPARK-38432][SQL][FOLLOWUP] Add test case for push down filter with alias ### What changes were proposed in this pull request? DS V2 pushdown predicates to data source supports column with alias. But Spark missing the test case for push down filter with alias. ### Why are the changes needed? Add test case for push down filter with alias ### Does this PR introduce _any_ user-facing change? 'No'. Just add a test case. ### How was this patch tested? New tests. Closes #35988 from beliefer/SPARK-38432_followup2. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 28f9c533cd1d1..8553774055665 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -350,6 +350,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel "WHEN IS_MANAGER = true THEN FALSE ELSE DEPT > 3 END], ") checkAnswer(df9, Seq(Row(2, "alex", 12000, 1200, false), Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) + + val df10 = spark.table("h2.test.people") + .select($"NAME".as("myName"), $"ID".as("myID")) + .filter($"myID" > 1) + checkFiltersRemoved(df10) + checkPushedInfo(df10, "PushedFilters: [ID IS NOT NULL, ID > 1], ") + checkAnswer(df10, Row("mary", 2)) } test("scan with filter push-down with ansi mode") { From 614cb93d452b513c9aaf837a32fe4ddb141287c0 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 29 Mar 2022 16:24:23 +0800 Subject: [PATCH 47/53] [SPARK-38633][SQL][FOLLOWUP] JDBCSQLBuilder should build cast to type of databases ### What changes were proposed in this pull request? DS V2 supports push down CAST to database. The current implement only uses the typeName of DataType. For example: `Cast(column, StringType)` will be build to `CAST(column AS String)`. But it should be `CAST(column AS TEXT)` for Postgres or `CAST(column AS VARCHAR2(255))` for Oracle. ### Why are the changes needed? Improve the implement of push down CAST. ### Does this PR introduce _any_ user-facing change? 'No'. Just new feature. ### How was this patch tested? Exists tests Closes #35999 from beliefer/SPARK-38633_followup. 
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index f9f90d8fb52b9..397942d7837db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -234,6 +234,12 @@ abstract class JdbcDialect extends Serializable with Logging{ } quoteIdentifier(namedRef.fieldNames.head) } + + override def visitCast(l: String, dataType: DataType): String = { + val databaseTypeDefinition = + getJDBCType(dataType).map(_.databaseTypeDefinition).getOrElse(dataType.typeName) + s"CAST($l AS $databaseTypeDefinition)" + } } /** From 30340707526eabfb1ed74bad118d9bc107622216 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 31 Mar 2022 19:18:58 +0800 Subject: [PATCH 48/53] [SPARK-37839][SQL][FOLLOWUP] Check overflow when DS V2 partial aggregate push-down `AVG` ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/35130 supports partial aggregate push-down `AVG` for DS V2. The behavior doesn't consistent with `Average` if occurs overflow in ansi mode. This PR closely follows the implement of `Average` to respect overflow in ansi mode. ### Why are the changes needed? Make the behavior consistent with `Average` if occurs overflow in ansi mode. ### Does this PR introduce _any_ user-facing change? 'Yes'. Users could see the exception about overflow throws in ansi mode. ### How was this patch tested? New tests. Closes #35320 from beliefer/SPARK-37839_followup. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../expressions/aggregate/Average.scala | 4 +- .../v2/V2ScanRelationPushDown.scala | 21 +++------ .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 44 ++++++++++++++++++- 3 files changed, 52 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 05f7edaeb5d48..533f7f20b2530 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -76,8 +76,8 @@ case class Average( case _ => DoubleType } - private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val count = AttributeReference("count", LongType)() + lazy val sum = AttributeReference("sum", sumDataType)() + lazy val count = AttributeReference("count", LongType)() override lazy val aggBufferAttributes = sum :: count :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 24e3a6c91b13d..cdcae15ef4e24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, 
DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.CollapseProject @@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType} +import org.apache.spark.sql.types.{DataType, LongType, StructType} import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { @@ -129,18 +129,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) - // Closely follow `Average.evaluateExpression` - avg.dataType match { - case _: YearMonthIntervalType => - If(EqualTo(count, Literal(0L)), - Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count)) - case _: DayTimeIntervalType => - If(EqualTo(count, Literal(0L)), - Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count)) - case _ => - // TODO deal with the overflow issue - Divide(addCastIfNeeded(sum, avg.dataType), - addCastIfNeeded(count, avg.dataType), false) + avg.evaluateExpression transform { + case a: Attribute if a.semanticEquals(avg.sum) => + addCastIfNeeded(sum, avg.sum.dataType) + case a: Attribute if a.semanticEquals(avg.count) => + addCastIfNeeded(count, avg.count.dataType) } } }.asInstanceOf[Seq[NamedExpression]] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 8553774055665..67a02904660c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -95,6 +95,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate() conn.prepareStatement( """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate() + + conn.prepareStatement( + "CREATE TABLE \"test\".\"item\" (id INTEGER, name TEXT(32), price NUMERIC(23, 3))") + .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " + + "(1, 'bottle', 11111111111111111111.123)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " + + "(1, 'bottle', 99999999999999999999.123)").executeUpdate() } } @@ -484,7 +492,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("show tables") { checkAnswer(sql("SHOW TABLES IN h2.test"), Seq(Row("test", "people", false), Row("test", "empty_table", 
false), - Row("test", "employee", false))) + Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false), + Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false))) } test("SQL API: create table as select") { @@ -1105,4 +1114,37 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) } + + test("scan with aggregate push-down: partial push-down AVG with overflow") { + def createDataFrame: DataFrame = spark.read + .option("partitionColumn", "id") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.item") + .agg(avg($"PRICE").as("avg")) + + Seq(true, false).foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + val df = createDataFrame + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + if (ansiEnabled) { + val e = intercept[SparkException] { + df.collect() + } + assert(e.getCause.isInstanceOf[ArithmeticException]) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || + e.getCause.getMessage.contains("Overflow in sum of decimals")) + } else { + checkAnswer(df, Seq(Row(null))) + } + } + } + } }
From 93690a0a8adddf93880bae13b1258e74b94ac369 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 1 Apr 2022 13:34:41 +0800 Subject: [PATCH 49/53] [SPARK-37960][SQL][FOLLOWUP] Make the testing CASE WHEN query more reasonable ### What changes were proposed in this pull request? Some testing CASE WHEN queries are not carefully written and do not make sense. In the future, the optimizer may get smarter and get rid of the CASE WHEN completely, and then we lose test coverage. This PR updates some CASE WHEN queries to make them more reasonable. ### Why are the changes needed? Future-proof test coverage. ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? N/A Closes #36032 from beliefer/SPARK-37960_followup2.
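The partial AVG push-down exercised by the test above works by pushing `SUM(PRICE)` and `COUNT(PRICE)` into the data source and letting Spark finish the average on top; under ANSI mode the pushed-down sum can overflow (which the test asserts), otherwise the overflow surfaces as `NULL`. The following is a minimal, self-contained Scala sketch of only the reassembly step, not Spark's actual rule: `PartialAgg` and `reassembleAvg` are invented names for illustration.

```scala
// Illustrative only: reassembling a final AVG from pushed-down partial
// SUM/COUNT results. PartialAgg and reassembleAvg are hypothetical names,
// not Spark APIs.
object AvgPushDownSketch {
  final case class PartialAgg(sum: BigDecimal, count: Long)

  // One PartialAgg per partition, as a JDBC source could return after
  // SUM(PRICE) and COUNT(PRICE) have been pushed down.
  def reassembleAvg(partials: Seq[PartialAgg]): Option[BigDecimal] = {
    val totalCount = partials.map(_.count).sum
    if (totalCount == 0L) None // AVG over zero rows yields NULL
    else Some(partials.map(_.sum).sum / BigDecimal(totalCount)) // final division done by Spark
  }

  def main(args: Array[String]): Unit = {
    val partials = Seq(
      PartialAgg(BigDecimal("11111111111111111111.123"), 1L),
      PartialAgg(BigDecimal("99999999999999999999.123"), 1L))
    println(reassembleAvg(partials)) // Some(55555555555555555555.123)
  }
}
```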
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 67a02904660c3..6a0a55b77881e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -888,13 +888,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel | COUNT(CASE WHEN SALARY > 11000 OR SALARY < 10000 THEN SALARY ELSE 0 END), | COUNT(CASE WHEN SALARY >= 12000 OR SALARY < 9000 THEN SALARY ELSE 0 END), | COUNT(CASE WHEN SALARY >= 12000 OR NOT(SALARY >= 9000) THEN SALARY ELSE 0 END), - | MAX(CASE WHEN NOT(SALARY > 8000) AND SALARY >= 8000 THEN SALARY ELSE 0 END), - | MAX(CASE WHEN NOT(SALARY > 8000) OR SALARY > 8000 THEN SALARY ELSE 0 END), - | MAX(CASE WHEN NOT(SALARY > 8000) AND NOT(SALARY < 8000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 10000) AND SALARY >= 8000 THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 10000) OR SALARY > 8000 THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 10000) AND NOT(SALARY < 8000) THEN SALARY ELSE 0 END), | MAX(CASE WHEN NOT(SALARY != 0) OR NOT(SALARY < 8000) THEN SALARY ELSE 0 END), | MAX(CASE WHEN NOT(SALARY > 8000 AND SALARY > 8000) THEN 0 ELSE SALARY END), | MIN(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NULL) THEN SALARY ELSE 0 END), - | SUM(CASE WHEN NOT(SALARY > 8000 AND SALARY IS NOT NULL) THEN SALARY ELSE 0 END), | SUM(CASE WHEN SALARY > 10000 THEN 2 WHEN SALARY > 8000 THEN 1 END), | AVG(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NOT NULL) THEN SALARY ELSE 0 END) |FROM h2.test.employee GROUP BY DEPT @@ -905,9 +904,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel " THEN SALARY ELSE 0.00 END), COUNT(CAS..., " + "PushedFilters: [], " + "PushedGroupByColumns: [DEPT], ") - checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 2, 0d), - Row(2, 2, 2, 2, 2, 0d, 10000d, 0d, 10000d, 10000d, 0d, 0d, 2, 0d), - Row(2, 2, 2, 2, 2, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 3, 0d))) + checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 2, 0d), + Row(2, 2, 2, 2, 2, 10000d, 10000d, 10000d, 10000d, 10000d, 0d, 2, 0d), + Row(2, 2, 2, 2, 2, 10000d, 12000d, 10000d, 12000d, 12000d, 0d, 3, 0d))) } test("scan with aggregate push-down: aggregate function with binary arithmetic") {
From a730da901cd451aa7415b8bb185bcf867a5dab8c Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 11 Apr 2022 13:50:57 +0800 Subject: [PATCH 50/53] [SPARK-38761][SQL] DS V2 supports push down misc non-aggregate functions ### What changes were proposed in this pull request? Currently, Spark has some misc non-aggregate functions from the ANSI standard. Please refer to https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L362. These functions are shown below: `abs`, `coalesce`, `nullif`, `CASE WHEN` DS V2 should support pushing down these misc non-aggregate functions. Because DS V2 already supports pushing down `CASE WHEN`, this PR does not need to cover it again. Because `nullif` extends `RuntimeReplaceable`, this PR does not need to cover it either. ### Why are the changes needed?
DS V2 supports push down misc non-aggregate functions ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New tests. Closes #36039 from beliefer/SPARK-38761. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../util/V2ExpressionSQLBuilder.java | 8 +++ .../catalyst/util/V2ExpressionBuilder.scala | 11 +++- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 50 ++++++++++--------- 3 files changed, 44 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index c8d924db75aed..a7d1ed7f85e84 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -93,6 +93,10 @@ public String build(Expression expr) { return visitNot(build(e.children()[0])); case "~": return visitUnaryArithmetic(name, inputToSQL(e.children()[0])); + case "ABS": + case "COALESCE": + return visitSQLFunction(name, + Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); case "CASE_WHEN": { List children = Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList()); @@ -210,6 +214,10 @@ protected String visitCaseWhen(String[] children) { return sb.toString(); } + protected String visitSQLFunction(String funcName, String[] inputs) { + return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")"; + } + protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException { throw new IllegalArgumentException("Unexpected V2 expression: " + expr); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index f12305d971209..9e852e7dd59fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableColumn @@ -96,6 +96,15 @@ class V2ExpressionBuilder( } case Cast(child, dataType, _, true) => generateExpression(child).map(v => new V2Cast(v, dataType)) + case Abs(child, true) => generateExpression(child) + .map(v => new GeneralScalarExpression("ABS", 
Array[V2Expression](v))) + case Coalesce(children) => + val childrenExpressions = children.flatMap(generateExpression(_)) + if (children.length == childrenExpressions.length) { + Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression])) + } else { + None + } case and: And => // AND expects predicate val l = generateExpression(and.left, true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 6a0a55b77881e..d99497822a2b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, not, sum, udf, when} +import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -381,19 +381,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2))) val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1) - checkFiltersRemoved(df2, ansiMode) - - df2.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = if (ansiMode) { - "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], " - } else { - "PushedFilters: [ID IS NOT NULL], " - } - checkKeywordsExistsInExplain(df2, expected_plan_fragment) + val expectedPlanFragment2 = if (ansiMode) { + "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], " + } else { + "PushedFilters: [ID IS NOT NULL], " } - + checkPushedInfo(df2, expectedPlanFragment2) if (ansiMode) { val e = intercept[SparkException] { checkAnswer(df2, Seq.empty) @@ -422,22 +416,30 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df4 = spark.table("h2.test.employee") .filter(($"salary" > 1000d).and($"salary" < 12000d)) - checkFiltersRemoved(df4, ansiMode) - - df4.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = if (ansiMode) { - "PushedFilters: [SALARY IS NOT NULL, " + - "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], " - } else { - "PushedFilters: [SALARY IS NOT NULL], " - } - checkKeywordsExistsInExplain(df4, expected_plan_fragment) + val expectedPlanFragment4 = if (ansiMode) { + "PushedFilters: [SALARY IS NOT NULL, " + + "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], " + } else { + "PushedFilters: [SALARY IS NOT NULL], " } - + checkPushedInfo(df4, expectedPlanFragment4) checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true), Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) + + val df5 = spark.table("h2.test.employee") + .filter(abs($"dept" - 3) > 1) + .filter(coalesce($"salary", $"bonus") > 2000) + checkFiltersRemoved(df5, ansiMode) + val expectedPlanFragment5 = if (ansiMode) { + "PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1, " + + "(COALESCE(CAST(SALARY AS double), 
BONUS)) > 2000.0]" + } else { + "PushedFilters: [DEPT IS NOT NULL]" + } + checkPushedInfo(df5, expectedPlanFragment5) + checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true), + Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true))) } } } From 95607854d8692564b2c41c99f7750361cf87ef23 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 12 Apr 2022 22:17:54 -0700 Subject: [PATCH 51/53] [SPARK-38865][SQL][DOCS] Update document of JDBC options for `pushDownAggregate` and `pushDownLimit` ### What changes were proposed in this pull request? Because the DS v2 pushdown framework refactored, we need to add more doc in `sql-data-sources-jdbc.md` to reflect the new changes. ### Why are the changes needed? Add doc for new changes for `pushDownAggregate` and `pushDownLimit`. ### Does this PR introduce _any_ user-facing change? 'No'. Updated for new feature. ### How was this patch tested? N/A Closes #36152 from beliefer/SPARK-38865. Authored-by: Jiaan Geng Signed-off-by: huaxingao --- docs/sql-data-sources-jdbc.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index 99e1a963a7954..e9af0ba274d7b 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -9,9 +9,9 @@ license: | The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -191,7 +191,7 @@ logging into the data sources. write - + cascadeTruncate the default cascading truncate behaviour of the JDBC database in question, specified in the isCascadeTruncate in each JDBCDialect @@ -241,7 +241,7 @@ logging into the data sources. pushDownAggregate false - The option to enable or disable aggregate push-down into the JDBC data source. The default value is false, in which case Spark will not push down aggregates to the JDBC data source. Otherwise, if sets to true, aggregates will be pushed down to the JDBC data source. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates can be pushed down if and only if all the aggregate functions and the related filters can be pushed down. Spark assumes that the data source can't fully complete the aggregate and does a final aggregate over the data source output. + The option to enable or disable aggregate push-down in V2 JDBC data source. The default value is false, in which case Spark will not push down aggregates to the JDBC data source. Otherwise, if sets to true, aggregates will be pushed down to the JDBC data source. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates can be pushed down if and only if all the aggregate functions and the related filters can be pushed down. If numPartitions equals to 1 or the group by key is the same as partitionColumn, Spark will push down aggregate to data source completely and not apply a final aggregate over the data source output. Otherwise, Spark will apply a final aggregate over the data source output. read @@ -250,7 +250,7 @@ logging into the data sources. 
pushDownLimit false - The option to enable or disable LIMIT push-down into V2 JDBC data source. The default value is false, in which case Spark does not push down LIMIT to the JDBC data source. Otherwise, if value sets to true, LIMIT is pushed down to the JDBC data source. SPARK still applies LIMIT on the result from data source even if LIMIT is pushed down. + The option to enable or disable LIMIT push-down into V2 JDBC data source. The LIMIT push-down also includes LIMIT + SORT , a.k.a. the Top N operator. The default value is false, in which case Spark does not push down LIMIT or LIMIT with SORT to the JDBC data source. Otherwise, if sets to true, LIMIT or LIMIT with SORT is pushed down to the JDBC data source. If numPartitions is greater than 1, SPARK still applies LIMIT or LIMIT with SORT on the result from data source even if LIMIT or LIMIT with SORT is pushed down. Otherwise, if LIMIT or LIMIT with SORT is pushed down and numPartitions equals to 1, SPARK will not apply LIMIT or LIMIT with SORT on the result from data source. read @@ -306,7 +306,7 @@ logging into the data sources. Note that kerberos authentication with keytab is not always supported by the JDBC driver.
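For readers wiring up the two options documented above, the following is a hedged usage sketch rather than part of the patch: it mirrors how JDBCV2Suite in this series registers an H2 database as a DS V2 JDBC catalog. The catalog name `h2`, the in-memory URL, the driver class, and the `test.employee` table are placeholder assumptions.

```scala
// Usage sketch only: enabling the documented push-down options on a DS V2
// JDBC catalog. The catalog name, URL, driver and table below are
// placeholders, not values taken from the patch.
import org.apache.spark.sql.SparkSession

object JdbcPushDownOptionsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("jdbc-push-down-options")
      .master("local[*]")
      .config("spark.sql.catalog.h2",
        "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
      .config("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb") // placeholder URL
      .config("spark.sql.catalog.h2.driver", "org.h2.Driver")
      .config("spark.sql.catalog.h2.pushDownAggregate", "true") // aggregate push-down
      .config("spark.sql.catalog.h2.pushDownLimit", "true")     // LIMIT / Top N push-down
      .getOrCreate()

    // With the options enabled, these queries become candidates for aggregate
    // push-down and LIMIT push-down respectively; the explain output lists
    // what was actually pushed to the database.
    spark.sql("SELECT DEPT, COUNT(*) FROM h2.test.employee GROUP BY DEPT").explain()
    spark.sql("SELECT NAME, SALARY FROM h2.test.employee LIMIT 3").explain()

    spark.stop()
  }
}
```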
Before using keytab and principal configuration options, please make sure the following requirements are met: -* The included JDBC driver version supports kerberos authentication with keytab. +* The included JDBC driver version supports kerberos authentication with keytab. * There is a built-in connection provider which supports the used database. There is a built-in connection providers for the following databases:
From 8c5860d9fe8207319d1b484a01d785166179113b Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 13 Apr 2022 14:41:47 +0800 Subject: [PATCH 52/53] [SPARK-38855][SQL] DS V2 supports push down math functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Currently, Spark has some math functions from the ANSI standard. Please refer to https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L388 These functions are shown below: `LN`, `EXP`, `POWER`, `SQRT`, `FLOOR`, `CEIL`, `WIDTH_BUCKET` Mainstream database support for these functions is shown below.

| Function | PostgreSQL | ClickHouse | H2 | MySQL | Oracle | Redshift | Presto | Teradata | Snowflake | DB2 | Vertica | Exasol | SqlServer | Yellowbrick | Impala | Mariadb | Druid | Pig | SQLite | Influxdata | Singlestore | ElasticSearch |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| `LN` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `EXP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `POWER` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes |
| `SQRT` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `FLOOR` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `CEIL` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `WIDTH_BUCKET` | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | No | No | Yes | No | No | No | No | No | No | No |

DS V2 should support pushing down these math functions. ### Why are the changes needed? DS V2 supports pushing down math functions. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New tests. Closes #36140 from beliefer/SPARK-38855.
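To make the feature concrete, here is a hedged, user-level sketch of predicates that become push-down candidates once these functions are translated. It mirrors the new `df6` test added to JDBCV2Suite later in this patch and assumes the suite's `h2` catalog and `test.employee` table are available; it is not a standalone program from the patch.

```scala
// Sketch only: DataFrame filters that map to LN / EXP / POWER / SQRT / FLOOR /
// CEIL in the SQL generated for the JDBC source when the dialect can compile
// them. Assumes a SparkSession with an "h2" JDBC catalog already registered.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{ceil, exp, floor, log => ln, pow, sqrt}

object MathFunctionPushDownSketch {
  def run(spark: SparkSession): Unit = {
    import spark.implicits._

    val df = spark.table("h2.test.employee")
      .filter(ln($"dept") > 1)       // LN
      .filter(exp($"salary") > 2000) // EXP
      .filter(pow($"dept", 2) > 4)   // POWER
      .filter(sqrt($"salary") > 100) // SQRT
      .filter(floor($"dept") > 1)    // FLOOR
      .filter(ceil($"dept") > 1)     // CEIL

    // Functions appear under PushedFilters when they were pushed; anything
    // the dialect cannot compile stays on the Spark side.
    df.explain()
  }
}
```

Note that the H2 dialect change later in this patch deliberately rejects `WIDTH_BUCKET`, so a `width_bucket` predicate is kept on the Spark side rather than pushed to H2.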
Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../expressions/GeneralScalarExpression.java | 54 +++++++++++++++++++ .../util/V2ExpressionSQLBuilder.java | 7 +++ .../sql/errors/QueryCompilationErrors.scala | 4 ++ .../catalyst/util/V2ExpressionBuilder.scala | 28 +++++++++- .../org/apache/spark/sql/jdbc/H2Dialect.scala | 26 +++++++++ .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 28 +++++++++- 6 files changed, 145 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index 8952761f9ef34..58082d5ee09c1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -94,6 +94,60 @@ *

 *    <li>Since version: 3.3.0</li>
 *   </ul>
 *  </li>
+ *  <li>Name: <code>ABS</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>ABS(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>COALESCE</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>COALESCE(expr1, expr2)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>LN</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>LN(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>EXP</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>EXP(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>POWER</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>POWER(expr, number)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SQRT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SQRT(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>FLOOR</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>FLOOR(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>CEIL</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>CEIL(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>WIDTH_BUCKET</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>WIDTH_BUCKET(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
 * </ol>
    * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off, * including: add, subtract, multiply, divide, remainder, pmod. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index a7d1ed7f85e84..c9dfa2003e3c1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -95,6 +95,13 @@ public String build(Expression expr) { return visitUnaryArithmetic(name, inputToSQL(e.children()[0])); case "ABS": case "COALESCE": + case "LN": + case "EXP": + case "POWER": + case "SQRT": + case "FLOOR": + case "CEIL": + case "WIDTH_BUCKET": return visitSQLFunction(name, Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); case "CASE_WHEN": { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 88c00c02597e7..0c7a1030fd434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2371,4 +2371,8 @@ object QueryCompilationErrors { messageParameters = Array(fieldName.quoted, path.quoted), origin = context) } + + def noSuchFunctionError(database: String, funcInfo: String): Throwable = { + new AnalysisException(s"$database does not support function: $funcInfo") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 9e852e7dd59fa..b9847d48b2e17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus} +import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket} import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableColumn @@ -105,6 +105,32 @@ class V2ExpressionBuilder( } else { None } + case Log(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v))) + case Exp(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("EXP", 
Array[V2Expression](v))) + case Pow(left, right) => + val l = generateExpression(left) + val r = generateExpression(right) + if (l.isDefined && r.isDefined) { + Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get))) + } else { + None + } + case Sqrt(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v))) + case Floor(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v))) + case Ceil(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v))) + case wb: WidthBucket => + val childrenExpressions = wb.children.flatMap(generateExpression(_)) + if (childrenExpressions.length == wb.children.length) { + Some(new GeneralScalarExpression("WIDTH_BUCKET", + childrenExpressions.toArray[V2Expression])) + } else { + None + } case and: And => // AND expects predicate val l = generateExpression(and.left, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 5f92f6dae9f11..6681aee778dbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -20,14 +20,40 @@ package org.apache.spark.sql.jdbc import java.sql.SQLException import java.util.Locale +import scala.util.control.NonFatal + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.errors.QueryCompilationErrors private object H2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2") + class H2SQLBuilder extends JDBCSQLBuilder { + override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { + funcName match { + case "WIDTH_BUCKET" => + val functionInfo = super.visitSQLFunction(funcName, inputs) + throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo) + case _ => super.visitSQLFunction(funcName, inputs) + } + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val h2SQLBuilder = new H2SQLBuilder() + try { + Some(h2SQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } + } + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index d99497822a2b8..94f044a0a6755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, 
not, sum, udf, when} +import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -440,6 +440,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkPushedInfo(df5, expectedPlanFragment5) checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true), Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true))) + + val df6 = spark.table("h2.test.employee") + .filter(ln($"dept") > 1) + .filter(exp($"salary") > 2000) + .filter(pow($"dept", 2) > 4) + .filter(sqrt($"salary") > 100) + .filter(floor($"dept") > 1) + .filter(ceil($"dept") > 1) + checkFiltersRemoved(df6, ansiMode) + val expectedPlanFragment6 = if (ansiMode) { + "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " + + "LN(CAST(DEPT AS double)) > 1.0, EXP(CAST(SALARY AS double)...," + } else { + "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]" + } + checkPushedInfo(df6, expectedPlanFragment6) + checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true))) + + // H2 does not support width_bucket + val df7 = sql(""" + |SELECT * FROM h2.test.employee + |WHERE width_bucket(dept, 1, 6, 3) > 1 + |""".stripMargin) + checkFiltersRemoved(df7, false) + checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]") + checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true))) } } } From 246101179698f71e84e50286b36907a68a837a1d Mon Sep 17 00:00:00 2001 From: chenzhx Date: Thu, 14 Apr 2022 12:00:45 +0800 Subject: [PATCH 53/53] update spark version to r61 --- assembly/pom.xml | 2 +- common/kvstore/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- examples/pom.xml | 2 +- external/avro/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10-token-provider/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- hadoop-cloud/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- repl/pom.xml | 2 +- resource-managers/kubernetes/core/pom.xml | 2 +- resource-managers/kubernetes/integration-tests/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 36 files changed, 36 insertions(+), 36 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 62888c64f7ceb..1c0a0e2a23786 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 046abd63ba1cf..cd4bca564165f 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index b833b1fcbc008..186ded95e30ac 100644 --- 
a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 12082c9b0dfe6..4453d4f3a39b3 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index a9accba20c2d7..08d185e39ec01 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index c3d8242e71752..a2ff08edf2762 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 38aa65c957e40..41536893ea953 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 725a378a13e8c..4ba2298503336 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index fcdabb29537bd..3265a4f8b0fe7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index 25a1a9131e65f..d0f2f1d724f2b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/external/avro/pom.xml b/external/avro/pom.xml index bc12b51b13b5e..daddd8770efc0 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 282d64e3459c6..fb29aec8a3403 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 8ef608a021fe7..dea5a0a23c92a 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 649c7af28e92a..148a9625fed19 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kafka-0-10-token-provider/pom.xml b/external/kafka-0-10-token-provider/pom.xml index dfbaa18d4c698..882c194653c84 100644 --- a/external/kafka-0-10-token-provider/pom.xml +++ b/external/kafka-0-10-token-provider/pom.xml @@ -21,7 +21,7 
@@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index eb5c9d4c9ca33..6774d254e80df 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 5d087dc9dd633..ab23f231fd978 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index e02048e28a2f7..55c892bdb7510 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index a8338509ad0e2..87259b3cd607e 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index da5bc3e8fcbce..5e2b40fd8d917 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 318921a298493..1206506f214e6 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 8d66de29b51bc..b609eaf181019 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index d9c22bf33e8e3..aae11098ae944 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index fcd7ade1810aa..796d55e7d7785 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/pom.xml b/pom.xml index 09de5d6f45ff7..25243d6e1132d 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 pom Spark Project Parent POM http://spark.apache.org/ diff --git a/repl/pom.xml b/repl/pom.xml index 36d9b0e5e43aa..714fdf9d0d8a5 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 77f4385e277a2..dcd4ceace7fad 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../../pom.xml diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 
1d12e2ebce1c7..95ea5e12c35bc 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 301462026b190..0c764d83c503a 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index db7e3e03107ec..d049e217637d3 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6c089f9feb3e3..631edbd8eb3e4 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 9afd9a3ef54b5..998de75018d4e 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index f1dcddd806525..dd3dabb82cc67 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 82bdeaf4e6608..e6bb5d5f49dd2 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 3a0f9a2f00c71..91db9435a87d4 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index c2b09a8508e2a..2d5830ad83d1c 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml