diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java deleted file mode 100644 index ebeee22a853cf..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.expressions; - -import java.io.Serializable; - -import org.apache.spark.annotation.Evolving; - -/** - * The general SQL string corresponding to expression. - * - * @since 3.3.0 - */ -@Evolving -public class GeneralSQLExpression implements Expression, Serializable { - private String sql; - - public GeneralSQLExpression(String sql) { - this.sql = sql; - } - - public String sql() { return sql; } - - @Override - public String toString() { return sql; } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java new file mode 100644 index 0000000000000..b3dd2cbfe3d7d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import java.io.Serializable; +import java.util.Arrays; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; + +// scalastyle:off line.size.limit +/** + * The general representation of SQL scalar expressions, which contains the upper-cased + * expression name and all the children expressions. + *

+ * The currently supported SQL scalar expressions:
+ *   1. Name: IS_NULL
+ *   2. Name: IS_NOT_NULL
+ *   3. Name: =
+ *   4. Name: !=
+ *   5. Name: <>
+ *   6. Name: <=>
+ *   7. Name: <
+ *   8. Name: <=
+ *   9. Name: >
+ *  10. Name: >=
+ *  11. Name: +
+ *  12. Name: -
+ *  13. Name: *
+ *  14. Name: /
+ *  15. Name: %
+ *  16. Name: &
+ *  17. Name: |
+ *  18. Name: ^
+ *  19. Name: AND
+ *  20. Name: OR
+ *  21. Name: NOT
+ *  22. Name: ~
+ *  23. Name: CASE_WHEN
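+ *
+ * For example, the catalyst expression `dept + 1` (with ANSI mode on, so Add is translatable)
+ * maps to the nested tree below; this is an illustrative Scala sketch mirroring what the new
+ * V2ExpressionBuilder produces, not code from this patch:
+ *
+ *   new GeneralScalarExpression("+",
+ *     Array[Expression](FieldReference.column("dept"), LiteralValue(1, IntegerType)))
+ *
+ * Its toString() is rendered by V2ExpressionSQLBuilder as `(dept) + (1)`.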
+ * Note: The SQL semantics conform to the ANSI standard, so some expressions are not supported
+ * when ANSI mode is off, including: add, subtract, multiply, divide, remainder, pmod.
+ *
+ * @since 3.3.0
+ */
+// scalastyle:on line.size.limit
+@Evolving
+public class GeneralScalarExpression implements Expression, Serializable {
+  private String name;
+  private Expression[] children;
+
+  public GeneralScalarExpression(String name, Expression[] children) {
+    this.name = name;
+    this.children = children;
+  }
+
+  public String name() { return name; }
+  public Expression[] children() { return children; }
+
+  @Override
+  public String toString() {
+    V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder();
+    try {
+      return builder.build(this);
+    } catch (Throwable e) {
+      // SQL generation failed; fall back to a name(child1,child2,...) rendering.
+      // Note that reduce(...) yields an Optional, which must be unwrapped.
+      return name + "(" +
+        Arrays.stream(children).map(child -> child.toString())
+          .reduce((a, b) -> a + "," + b).orElse("") + ")";
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
new file mode 100644
index 0000000000000..0af0d88b0f622
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.util;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.FieldReference;
+import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
+import org.apache.spark.sql.connector.expressions.LiteralValue;
+
+/**
+ * The builder to generate SQL from V2 expressions.
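+ *
+ * Connectors can subclass this builder and override individual visit methods to adapt the
+ * generated SQL to their own dialect. A minimal Scala sketch (the `quote` parameter is an
+ * assumed helper, not part of this patch), similar in spirit to the JDBCSQLBuilder added to
+ * JdbcDialects.scala later in this change:
+ *
+ *   class DialectSQLBuilder(quote: String => String) extends V2ExpressionSQLBuilder {
+ *     // Render a field reference with dialect-specific identifier quoting.
+ *     override def visitFieldReference(fieldRef: FieldReference): String =
+ *       fieldRef.fieldNames.map(quote).mkString(".")
+ *   }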
+ */
+public class V2ExpressionSQLBuilder {
+  public String build(Expression expr) {
+    if (expr instanceof LiteralValue) {
+      return visitLiteral((LiteralValue) expr);
+    } else if (expr instanceof FieldReference) {
+      return visitFieldReference((FieldReference) expr);
+    } else if (expr instanceof GeneralScalarExpression) {
+      GeneralScalarExpression e = (GeneralScalarExpression) expr;
+      String name = e.name();
+      switch (name) {
+        case "IS_NULL":
+          return visitIsNull(build(e.children()[0]));
+        case "IS_NOT_NULL":
+          return visitIsNotNull(build(e.children()[0]));
+        case "=":
+        case "!=":
+        case "<=>":
+        case "<":
+        case "<=":
+        case ">":
+        case ">=":
+          return visitBinaryComparison(name, build(e.children()[0]), build(e.children()[1]));
+        case "+":
+        case "*":
+        case "/":
+        case "%":
+        case "&":
+        case "|":
+        case "^":
+          return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1]));
+        case "-":
+          if (e.children().length == 1) {
+            return visitUnaryArithmetic(name, build(e.children()[0]));
+          } else {
+            return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1]));
+          }
+        case "AND":
+          return visitAnd(name, build(e.children()[0]), build(e.children()[1]));
+        case "OR":
+          return visitOr(name, build(e.children()[0]), build(e.children()[1]));
+        case "NOT":
+          return visitNot(build(e.children()[0]));
+        case "~":
+          return visitUnaryArithmetic(name, build(e.children()[0]));
+        case "CASE_WHEN":
+          List<String> children = new ArrayList<>();
+          for (Expression child : e.children()) {
+            children.add(build(child));
+          }
+          return visitCaseWhen(children.toArray(new String[e.children().length]));
+        // TODO: support other expressions
+        default:
+          return visitUnexpectedExpr(expr);
+      }
+    } else {
+      return visitUnexpectedExpr(expr);
+    }
+  }
+
+  protected String visitLiteral(LiteralValue literalValue) {
+    return literalValue.toString();
+  }
+
+  protected String visitFieldReference(FieldReference fieldRef) {
+    return fieldRef.toString();
+  }
+
+  protected String visitIsNull(String v) {
+    return v + " IS NULL";
+  }
+
+  protected String visitIsNotNull(String v) {
+    return v + " IS NOT NULL";
+  }
+
+  protected String visitBinaryComparison(String name, String l, String r) {
+    return "(" + l + ") " + name + " (" + r + ")";
+  }
+
+  protected String visitBinaryArithmetic(String name, String l, String r) {
+    return "(" + l + ") " + name + " (" + r + ")";
+  }
+
+  protected String visitAnd(String name, String l, String r) {
+    return "(" + l + ") " + name + " (" + r + ")";
+  }
+
+  protected String visitOr(String name, String l, String r) {
+    return "(" + l + ") " + name + " (" + r + ")";
+  }
+
+  protected String visitNot(String v) {
+    return "NOT (" + v + ")";
+  }
+
+  protected String visitUnaryArithmetic(String name, String v) { return name + " (" + v + ")"; }
+
+  protected String visitCaseWhen(String[] children) {
+    StringBuilder sb = new StringBuilder("CASE");
+    for (int i = 0; i < children.length; i += 2) {
+      String c = children[i];
+      int j = i + 1;
+      if (j < children.length) {
+        String v = children[j];
+        sb.append(" WHEN ");
+        sb.append(c);
+        sb.append(" THEN ");
+        sb.append(v);
+      } else {
+        sb.append(" ELSE ");
+        sb.append(c);
+      }
+    }
+    sb.append(" END");
+    return sb.toString();
+  }
+
+  protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
+    throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala deleted file mode 100644 index 6239d0e2e7ae8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryOperator, CaseWhen, EqualTo, Expression, IsNotNull, IsNull, Literal, Not} -import org.apache.spark.sql.connector.expressions.LiteralValue - -/** - * The builder to generate SQL string from catalyst expressions. - */ -class ExpressionSQLBuilder(e: Expression) { - - def build(): Option[String] = generateSQL(e) - - private def generateSQL(expr: Expression): Option[String] = expr match { - case Literal(value, dataType) => Some(LiteralValue(value, dataType).toString) - case a: Attribute => Some(quoteIfNeeded(a.name)) - case IsNull(col) => generateSQL(col).map(c => s"$c IS NULL") - case IsNotNull(col) => generateSQL(col).map(c => s"$c IS NOT NULL") - case b: BinaryOperator => - val l = generateSQL(b.left) - val r = generateSQL(b.right) - if (l.isDefined && r.isDefined) { - Some(s"(${l.get}) ${b.sqlOperator} (${r.get})") - } else { - None - } - case Not(EqualTo(left, right)) => - val l = generateSQL(left) - val r = generateSQL(right) - if (l.isDefined && r.isDefined) { - Some(s"${l.get} != ${r.get}") - } else { - None - } - case Not(child) => generateSQL(child).map(v => s"NOT ($v)") - case CaseWhen(branches, elseValue) => - val conditionsSQL = branches.map(_._1).flatMap(generateSQL) - val valuesSQL = branches.map(_._2).flatMap(generateSQL) - if (conditionsSQL.length == branches.length && valuesSQL.length == branches.length) { - val branchSQL = - conditionsSQL.zip(valuesSQL).map { case (c, v) => s" WHEN $c THEN $v" }.mkString - if (elseValue.isDefined) { - elseValue.flatMap(generateSQL).map(v => s"CASE$branchSQL ELSE $v END") - } else { - Some(s"CASE$branchSQL END") - } - } else { - None - } - // TODO supports other expressions - case _ => None - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala new file mode 100644 index 0000000000000..1e361695056a7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal, Multiply, Not, Or, Remainder, Subtract, UnaryMinus}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+
+/**
+ * The builder to generate V2 expressions from catalyst expressions.
+ */
+class V2ExpressionBuilder(e: Expression) {
+
+  def build(): Option[V2Expression] = generateExpression(e)
+
+  private def canTranslate(b: BinaryOperator) = b match {
+    case _: And | _: Or => true
+    case _: BinaryComparison => true
+    case _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true
+    case add: Add => add.failOnError
+    case sub: Subtract => sub.failOnError
+    case mul: Multiply => mul.failOnError
+    case div: Divide => div.failOnError
+    case r: Remainder => r.failOnError
+    case _ => false
+  }
+
+  private def generateExpression(expr: Expression): Option[V2Expression] = expr match {
+    case Literal(value, dataType) => Some(LiteralValue(value, dataType))
+    case attr: Attribute => Some(FieldReference.column(attr.name))
+    case IsNull(col) => generateExpression(col)
+      .map(c => new GeneralScalarExpression("IS_NULL", Array[V2Expression](c)))
+    case IsNotNull(col) => generateExpression(col)
+      .map(c => new GeneralScalarExpression("IS_NOT_NULL", Array[V2Expression](c)))
+    case b: BinaryOperator if canTranslate(b) =>
+      val left = generateExpression(b.left)
+      val right = generateExpression(b.right)
+      if (left.isDefined && right.isDefined) {
+        Some(new GeneralScalarExpression(b.sqlOperator, Array[V2Expression](left.get, right.get)))
+      } else {
+        None
+      }
+    case Not(eq: EqualTo) =>
+      val left = generateExpression(eq.left)
+      val right = generateExpression(eq.right)
+      if (left.isDefined && right.isDefined) {
+        Some(new GeneralScalarExpression("!=", Array[V2Expression](left.get, right.get)))
+      } else {
+        None
+      }
+    case Not(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("NOT", Array[V2Expression](v)))
+    case UnaryMinus(child, true) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("-", Array[V2Expression](v)))
+    case BitwiseNot(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("~", Array[V2Expression](v)))
+    case CaseWhen(branches, elseValue) =>
+      val conditions = branches.map(_._1).flatMap(generateExpression)
+      val values = branches.map(_._2).flatMap(generateExpression)
+      if (conditions.length == branches.length && values.length == branches.length) {
+        val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
+          Seq[V2Expression](c, v)
+        }
+        if (elseValue.isDefined) {
+          elseValue.flatMap(generateExpression).map { v =>
+            val children = (branchExpressions :+ v).toArray[V2Expression]
+            // The children look like [condition1, value1, ..., conditionN, valueN, elseValue]
+            new GeneralScalarExpression("CASE_WHEN", children)
+          }
+        } else {
+          // The children look like [condition1, value1, ..., conditionN, valueN]
+          Some(new GeneralScalarExpression("CASE_WHEN", branchExpressions.toArray[V2Expression]))
+        }
+      } else {
+        None
+      }
+    // TODO: support other expressions
+    case _ => None
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a1602a3aa4880..c386655c947f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -38,10 +38,10 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
-import org.apache.spark.sql.catalyst.util.ExpressionSQLBuilder
+import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
-import org.apache.spark.sql.connector.expressions.{Expression => ExpressionV2, FieldReference, GeneralSQLExpression, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
@@ -776,8 +776,8 @@ object DataSourceStrategy
     Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray))
   }
 
-  protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
-    def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
+  protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = {
+    def translateOortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
       case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
         val directionV2 = directionV1 match {
           case Ascending => SortDirection.ASCENDING
@@ -864,8 +864,8 @@ object PushableColumnWithoutNestedColumn extends PushableColumnBase {
 * Get the expression of DS V2 to represent catalyst expression that can be pushed down.
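+ * For example, the catalyst predicate `dept > 1` is translated to the V2 expression below;
+ * an illustrative Scala sketch of the result, not a line from this patch:
+ *
+ *   new GeneralScalarExpression(">",
+ *     Array[V2Expression](FieldReference.column("dept"), LiteralValue(1, IntegerType)))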
 */
object PushableExpression {
-  def unapply(e: Expression): Option[ExpressionV2] = e match {
+  def unapply(e: Expression): Option[V2Expression] = e match {
     case PushableColumnWithoutNestedColumn(name) => Some(FieldReference.column(name))
-    case _ => new ExpressionSQLBuilder(e).build().map(new GeneralSQLExpression(_))
+    case _ => new V2ExpressionBuilder(e).build()
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 2d10bbf5de537..a7e0ec8b72a7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -32,8 +32,9 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, Timesta
 import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.index.TableIndex
-import org.apache.spark.sql.connector.expressions.{FieldReference, GeneralSQLExpression, NamedReference}
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -194,6 +195,31 @@ abstract class JdbcDialect extends Serializable with Logging{
     case _ => value
   }
 
+  class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
+    override def visitFieldReference(fieldRef: FieldReference): String = {
+      if (fieldRef.fieldNames().length != 1) {
+        throw new IllegalArgumentException(
+          "FieldReference with multiple or zero name parts is not supported: " + fieldRef)
+      }
+      quoteIdentifier(fieldRef.fieldNames.head)
+    }
+  }
+
+  /**
+   * Converts a V2 expression to a String representing the SQL expression.
+   * @param expr The V2 expression to be converted.
+   * @return The compiled SQL expression, or None if the expression cannot be compiled.
+   */
+  @Since("3.3.0")
+  def compileExpression(expr: Expression): Option[String] = {
+    val jdbcSQLBuilder = new JDBCSQLBuilder()
+    try {
+      Some(jdbcSQLBuilder.build(expr))
+    } catch {
+      case _: IllegalArgumentException => None
+    }
+  }
+
   /**
    * Converts aggregate function to String representing a SQL expression.
    * @param aggFunction The aggregate function to be converted.
@@ -203,55 +229,20 @@ abstract class JdbcDialect extends Serializable with Logging{ def compileAggregate(aggFunction: AggregateFunc): Option[String] = { aggFunction match { case min: Min => - val sql = min.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } - Some(s"MIN($sql)") + compileExpression(min.column).map(v => s"MIN($v)") case max: Max => - val sql = max.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } - Some(s"MAX($sql)") + compileExpression(max.column).map(v => s"MAX($v)") case count: Count => - val sql = count.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } val distinct = if (count.isDistinct) "DISTINCT " else "" - Some(s"COUNT($distinct$sql)") + compileExpression(count.column).map(v => s"COUNT($distinct$v)") case sum: Sum => - val sql = sum.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } val distinct = if (sum.isDistinct) "DISTINCT " else "" - Some(s"SUM($distinct$sql)") + compileExpression(sum.column).map(v => s"SUM($distinct$v)") case _: CountStar => Some("COUNT(*)") case avg: Avg => - val sql = avg.column match { - case field: FieldReference => - if (field.fieldNames.length != 1) return None - quoteIdentifier(field.fieldNames.head) - case expr: GeneralSQLExpression => - expr.sql() - } val distinct = if (avg.isDistinct) "DISTINCT " else "" - Some(s"AVG($distinct$sql)") + compileExpression(avg.column).map(v => s"AVG($distinct$v)") case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 3edc4b9502064..fb1d049e8e2d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} @@ -97,7 +95,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): InMemoryTable = { + properties: java.util.Map[String, String]): InMemoryTable = { new InMemoryTable(name, schema, partitions, properties) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index a1463523d38ff..92a5c552108b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connector -import java.util import java.util.Collections import test.org.apache.spark.sql.connector.catalog.functions._ @@ -37,7 +36,7 @@ import 
org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { - private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String] private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = { catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index fd3c69eff5652..00d2a445ab47e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector import java.io.File -import java.util import java.util.OptionalLong import test.org.apache.spark.sql.connector._ @@ -552,7 +551,7 @@ abstract class SimpleBatchTable extends Table with SupportsRead { override def name(): String = this.getClass.toString - override def capabilities(): util.Set[TableCapability] = util.EnumSet.of(BATCH_READ) + override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ) } abstract class SimpleScanBuilder extends ScanBuilder @@ -575,7 +574,7 @@ trait TestingV2Source extends TableProvider { override def getTable( schema: StructType, partitioning: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { getTable(new CaseInsensitiveStringMap(properties)) } @@ -792,7 +791,7 @@ class SchemaRequiredDataSource extends TableProvider { override def getTable( schema: StructType, partitioning: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala index 094667001b6c3..e3d61a846fdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} @@ -61,7 +59,7 @@ class TestLocalScanCatalog extends BasicInMemoryTableCatalog { ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val table = new TestLocalScanTable(ident.toString) tables.put(ident, table) table @@ -76,8 +74,8 @@ object TestLocalScanTable { class TestLocalScanTable(override val name: String) extends Table with SupportsRead { override def schema(): StructType = TestLocalScanTable.schema - override def capabilities(): util.Set[TableCapability] = - util.EnumSet.of(TableCapability.BATCH_READ) + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) override def newScanBuilder(options: CaseInsensitiveStringMap): 
ScanBuilder = new TestLocalScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 99c322a7155f2..64c893ed74fdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util import scala.collection.JavaConverters._ @@ -138,8 +137,8 @@ class SimpleWritableDataSource extends TestingV2Source { new MyWriteBuilder(path, info) } - override def capabilities(): util.Set[TableCapability] = - util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE) + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE) } override def getTable(options: CaseInsensitiveStringMap): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index a12065ec0ab2a..5f2e0b28aeccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} @@ -215,8 +213,8 @@ private case object TestRelation extends LeafNode with NamedRelation { private case class CapabilityTable(_capabilities: TableCapability*) extends Table { override def name(): String = "capability_test_table" override def schema(): StructType = TableCapabilityCheckSuite.schema - override def capabilities(): util.Set[TableCapability] = { - val set = util.EnumSet.noneOf(classOf[TableCapability]) + override def capabilities(): java.util.Set[TableCapability] = { + val set = java.util.EnumSet.noneOf(classOf[TableCapability]) _capabilities.foreach(set.add) set } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala index bf2749d1afc53..0a0aaa8021996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connector -import java.util import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicBoolean @@ -35,7 +34,7 @@ import org.apache.spark.sql.types.StructType */ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension { - protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() + protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() private val tableCreated: AtomicBoolean = new AtomicBoolean(false) @@ -48,7 +47,7 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, 
String]): T + properties: java.util.Map[String, String]): T override def loadTable(ident: Identifier): Table = { if (tables.containsKey(ident)) { @@ -69,12 +68,12 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val key = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY val propsWithLocation = if (properties.containsKey(key)) { // Always set a location so that CREATE EXTERNAL TABLE won't fail with LOCATION not specified. if (!properties.containsKey(TableCatalog.PROP_LOCATION)) { - val newProps = new util.HashMap[String, String]() + val newProps = new java.util.HashMap[String, String]() newProps.putAll(properties) newProps.put(TableCatalog.PROP_LOCATION, "file:/abc") newProps diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index ff1bd29808637..c5be222645b19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext} import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} @@ -104,7 +102,7 @@ class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog { ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { // To simplify the test implementation, only support fixed schema. 
if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) { throw new UnsupportedOperationException @@ -129,8 +127,8 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp override def schema(): StructType = V1ReadFallbackCatalog.schema - override def capabilities(): util.Set[TableCapability] = { - util.EnumSet.of(TableCapability.BATCH_READ) + override def capabilities(): java.util.Set[TableCapability] = { + java.util.EnumSet.of(TableCapability.BATCH_READ) } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 9fbaf7890f8f8..992c46cc6cdb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import scala.collection.JavaConverters._ import scala.collection.mutable @@ -223,7 +221,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { + properties: java.util.Map[String, String]): InMemoryTableWithV1Fallback = { val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) InMemoryV1Provider.tables.put(name, t) tables.put(Identifier.of(Array("default"), name), t) @@ -321,7 +319,7 @@ class InMemoryTableWithV1Fallback( override val name: String, override val schema: StructType, override val partitioning: Array[Transform], - override val properties: util.Map[String, String]) + override val properties: java.util.Map[String, String]) extends Table with SupportsWrite with SupportsRead { @@ -331,7 +329,7 @@ class InMemoryTableWithV1Fallback( } } - override def capabilities: util.Set[TableCapability] = util.EnumSet.of( + override def capabilities: java.util.Set[TableCapability] = java.util.EnumSet.of( TableCapability.BATCH_READ, TableCapability.V1_BATCH_WRITE, TableCapability.OVERWRITE_BY_FILTER, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index aa0289ae75bdb..3f90fb47efb28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager} import java.util.Properties -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} @@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{avg, count, lit, sum, udf} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -841,6 +842,34 @@ class JDBCV2Suite extends 
QueryTest with SharedSparkSession with ExplainSuiteHel Row(2, 2, 2, 2, 2, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 3, 0d))) } + test("scan with aggregate push-down: aggregate function with binary arithmetic") { + Seq(false, true).foreach { ansiMode => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { + val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") + checkAggregateRemoved(df, ansiMode) + val expected_plan_fragment = if (ansiMode) { + "PushedAggregates: [SUM((2147483647) + (DEPT))], " + + "PushedFilters: [], PushedGroupByColumns: []" + } else { + "PushedFilters: []" + } + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + if (ansiMode) { + val e = intercept[SparkException] { + checkAnswer(df, Seq(Row(-10737418233L))) + } + assert(e.getMessage.contains( + "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\"")) + } else { + checkAnswer(df, Seq(Row(-10737418233L))) + } + } + } + } + test("scan with aggregate push-down: aggregate function with UDF") { val df = spark.table("h2.test.employee") val decrease = udf { (x: Double, y: Double) => x - y }