diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java
deleted file mode 100644
index ebeee22a853cf..0000000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralSQLExpression.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.connector.expressions;
-
-import java.io.Serializable;
-
-import org.apache.spark.annotation.Evolving;
-
-/**
- * The general SQL string corresponding to expression.
- *
- * @since 3.3.0
- */
-@Evolving
-public class GeneralSQLExpression implements Expression, Serializable {
- private String sql;
-
- public GeneralSQLExpression(String sql) {
- this.sql = sql;
- }
-
- public String sql() { return sql; }
-
- @Override
- public String toString() { return sql; }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
new file mode 100644
index 0000000000000..b3dd2cbfe3d7d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder;
+
+// scalastyle:off line.size.limit
+/**
+ * The general representation of SQL scalar expressions, which contains the upper-cased
+ * expression name and all of its child expressions.
+ *
+ * The currently supported SQL scalar expressions:
+ *
+ * - Name: IS_NULL
+ *   - SQL semantic: expr IS NULL
+ *   - Since version: 3.3.0
+ * - Name: IS_NOT_NULL
+ *   - SQL semantic: expr IS NOT NULL
+ *   - Since version: 3.3.0
+ * - Name: =
+ *   - SQL semantic: expr1 = expr2
+ *   - Since version: 3.3.0
+ * - Name: !=
+ *   - SQL semantic: expr1 != expr2
+ *   - Since version: 3.3.0
+ * - Name: <>
+ *   - SQL semantic: expr1 <> expr2
+ *   - Since version: 3.3.0
+ * - Name: <=>
+ *   - SQL semantic: expr1 <=> expr2
+ *   - Since version: 3.3.0
+ * - Name: <
+ *   - SQL semantic: expr1 < expr2
+ *   - Since version: 3.3.0
+ * - Name: <=
+ *   - SQL semantic: expr1 <= expr2
+ *   - Since version: 3.3.0
+ * - Name: >
+ *   - SQL semantic: expr1 > expr2
+ *   - Since version: 3.3.0
+ * - Name: >=
+ *   - SQL semantic: expr1 >= expr2
+ *   - Since version: 3.3.0
+ * - Name: +
+ *   - SQL semantic: expr1 + expr2
+ *   - Since version: 3.3.0
+ * - Name: -
+ *   - SQL semantic: expr1 - expr2 or - expr
+ *   - Since version: 3.3.0
+ * - Name: *
+ *   - SQL semantic: expr1 * expr2
+ *   - Since version: 3.3.0
+ * - Name: /
+ *   - SQL semantic: expr1 / expr2
+ *   - Since version: 3.3.0
+ * - Name: %
+ *   - SQL semantic: expr1 % expr2
+ *   - Since version: 3.3.0
+ * - Name: &
+ *   - SQL semantic: expr1 & expr2
+ *   - Since version: 3.3.0
+ * - Name: |
+ *   - SQL semantic: expr1 | expr2
+ *   - Since version: 3.3.0
+ * - Name: ^
+ *   - SQL semantic: expr1 ^ expr2
+ *   - Since version: 3.3.0
+ * - Name: AND
+ *   - SQL semantic: expr1 AND expr2
+ *   - Since version: 3.3.0
+ * - Name: OR
+ *   - SQL semantic: expr1 OR expr2
+ *   - Since version: 3.3.0
+ * - Name: NOT
+ *   - SQL semantic: NOT expr
+ *   - Since version: 3.3.0
+ * - Name: ~
+ *   - SQL semantic: ~ expr
+ *   - Since version: 3.3.0
+ * - Name: CASE_WHEN
+ *   - SQL semantic:
+ *     CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END
+ *   - Since version: 3.3.0
+ *
+ * Note: The SQL semantics follow the ANSI standard, so some expressions are not supported when
+ * ANSI mode is off, including: add, subtract, multiply, divide, remainder, pmod.
+ *
+ * @since 3.3.0
+ */
+// scalastyle:on line.size.limit
+@Evolving
+public class GeneralScalarExpression implements Expression, Serializable {
+ private String name;
+ private Expression[] children;
+
+ public GeneralScalarExpression(String name, Expression[] children) {
+ this.name = name;
+ this.children = children;
+ }
+
+ public String name() { return name; }
+ public Expression[] children() { return children; }
+
+ @Override
+ public String toString() {
+ V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder();
+ try {
+ return builder.build(this);
+ } catch (Throwable e) {
+      return name + "(" + Arrays.stream(children)
+        .map(child -> child.toString()).reduce((a, b) -> a + ", " + b).orElse("") + ")";
+ }
+ }
+}
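
A minimal Scala sketch of the representation above (not part of this diff, e.g. for spark-shell);
the column name "dept" is a placeholder:

  import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
  import org.apache.spark.sql.types.IntegerType

  // The equality predicate dept = 1, expressed as the generic name "=" plus two children.
  val eq = new GeneralScalarExpression("=",
    Array[Expression](FieldReference.column("dept"), LiteralValue(1, IntegerType)))

  // toString delegates to V2ExpressionSQLBuilder and should print roughly "(dept) = (1)".
  println(eq)
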
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
new file mode 100644
index 0000000000000..0af0d88b0f622
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.util;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.FieldReference;
+import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
+import org.apache.spark.sql.connector.expressions.LiteralValue;
+
+/**
+ * The builder to generate SQL from V2 expressions.
+ */
+public class V2ExpressionSQLBuilder {
+ public String build(Expression expr) {
+ if (expr instanceof LiteralValue) {
+ return visitLiteral((LiteralValue) expr);
+ } else if (expr instanceof FieldReference) {
+ return visitFieldReference((FieldReference) expr);
+ } else if (expr instanceof GeneralScalarExpression) {
+ GeneralScalarExpression e = (GeneralScalarExpression) expr;
+ String name = e.name();
+ switch (name) {
+ case "IS_NULL":
+ return visitIsNull(build(e.children()[0]));
+ case "IS_NOT_NULL":
+ return visitIsNotNull(build(e.children()[0]));
+ case "=":
+ case "!=":
+ case "<=>":
+ case "<":
+ case "<=":
+ case ">":
+ case ">=":
+ return visitBinaryComparison(name, build(e.children()[0]), build(e.children()[1]));
+ case "+":
+ case "*":
+ case "/":
+ case "%":
+ case "&":
+ case "|":
+ case "^":
+ return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1]));
+ case "-":
+ if (e.children().length == 1) {
+ return visitUnaryArithmetic(name, build(e.children()[0]));
+ } else {
+ return visitBinaryArithmetic(name, build(e.children()[0]), build(e.children()[1]));
+ }
+ case "AND":
+ return visitAnd(name, build(e.children()[0]), build(e.children()[1]));
+ case "OR":
+ return visitOr(name, build(e.children()[0]), build(e.children()[1]));
+ case "NOT":
+ return visitNot(build(e.children()[0]));
+ case "~":
+ return visitUnaryArithmetic(name, build(e.children()[0]));
+ case "CASE_WHEN":
+          List<String> children = new ArrayList<>();
+ for (Expression child : e.children()) {
+ children.add(build(child));
+ }
+ return visitCaseWhen(children.toArray(new String[e.children().length]));
+ // TODO supports other expressions
+ default:
+ return visitUnexpectedExpr(expr);
+ }
+ } else {
+ return visitUnexpectedExpr(expr);
+ }
+ }
+
+ protected String visitLiteral(LiteralValue literalValue) {
+ return literalValue.toString();
+ }
+
+ protected String visitFieldReference(FieldReference fieldRef) {
+ return fieldRef.toString();
+ }
+
+ protected String visitIsNull(String v) {
+ return v + " IS NULL";
+ }
+
+ protected String visitIsNotNull(String v) {
+ return v + " IS NOT NULL";
+ }
+
+ protected String visitBinaryComparison(String name, String l, String r) {
+ return "(" + l + ") " + name + " (" + r + ")";
+ }
+
+ protected String visitBinaryArithmetic(String name, String l, String r) {
+ return "(" + l + ") " + name + " (" + r + ")";
+ }
+
+ protected String visitAnd(String name, String l, String r) {
+ return "(" + l + ") " + name + " (" + r + ")";
+ }
+
+ protected String visitOr(String name, String l, String r) {
+ return "(" + l + ") " + name + " (" + r + ")";
+ }
+
+ protected String visitNot(String v) {
+ return "NOT (" + v + ")";
+ }
+
+  protected String visitUnaryArithmetic(String name, String v) { return name + " (" + v + ")"; }
+
+ protected String visitCaseWhen(String[] children) {
+ StringBuilder sb = new StringBuilder("CASE");
+ for (int i = 0; i < children.length; i += 2) {
+ String c = children[i];
+ int j = i + 1;
+ if (j < children.length) {
+ String v = children[j];
+ sb.append(" WHEN ");
+ sb.append(c);
+ sb.append(" THEN ");
+ sb.append(v);
+ } else {
+ sb.append(" ELSE ");
+ sb.append(c);
+ }
+ }
+ sb.append(" END");
+ return sb.toString();
+ }
+
+ protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
+ throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
+ }
+}
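
A minimal Scala sketch (not part of this diff) of how a dialect is expected to customize SQL
generation by overriding individual visit methods; the back-quoting rule is a made-up example:

  import org.apache.spark.sql.connector.expressions.FieldReference
  import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

  // Hypothetical subclass: back-quote every part of a field name instead of emitting it verbatim.
  class BackQuotingSQLBuilder extends V2ExpressionSQLBuilder {
    override def visitFieldReference(fieldRef: FieldReference): String =
      fieldRef.fieldNames().map(part => s"`$part`").mkString(".")
  }
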
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala
deleted file mode 100644
index 6239d0e2e7ae8..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/ExpressionSQLBuilder.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.util
-
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryOperator, CaseWhen, EqualTo, Expression, IsNotNull, IsNull, Literal, Not}
-import org.apache.spark.sql.connector.expressions.LiteralValue
-
-/**
- * The builder to generate SQL string from catalyst expressions.
- */
-class ExpressionSQLBuilder(e: Expression) {
-
- def build(): Option[String] = generateSQL(e)
-
- private def generateSQL(expr: Expression): Option[String] = expr match {
- case Literal(value, dataType) => Some(LiteralValue(value, dataType).toString)
- case a: Attribute => Some(quoteIfNeeded(a.name))
- case IsNull(col) => generateSQL(col).map(c => s"$c IS NULL")
- case IsNotNull(col) => generateSQL(col).map(c => s"$c IS NOT NULL")
- case b: BinaryOperator =>
- val l = generateSQL(b.left)
- val r = generateSQL(b.right)
- if (l.isDefined && r.isDefined) {
- Some(s"(${l.get}) ${b.sqlOperator} (${r.get})")
- } else {
- None
- }
- case Not(EqualTo(left, right)) =>
- val l = generateSQL(left)
- val r = generateSQL(right)
- if (l.isDefined && r.isDefined) {
- Some(s"${l.get} != ${r.get}")
- } else {
- None
- }
- case Not(child) => generateSQL(child).map(v => s"NOT ($v)")
- case CaseWhen(branches, elseValue) =>
- val conditionsSQL = branches.map(_._1).flatMap(generateSQL)
- val valuesSQL = branches.map(_._2).flatMap(generateSQL)
- if (conditionsSQL.length == branches.length && valuesSQL.length == branches.length) {
- val branchSQL =
- conditionsSQL.zip(valuesSQL).map { case (c, v) => s" WHEN $c THEN $v" }.mkString
- if (elseValue.isDefined) {
- elseValue.flatMap(generateSQL).map(v => s"CASE$branchSQL ELSE $v END")
- } else {
- Some(s"CASE$branchSQL END")
- }
- } else {
- None
- }
- // TODO supports other expressions
- case _ => None
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
new file mode 100644
index 0000000000000..1e361695056a7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal, Multiply, Not, Or, Remainder, Subtract, UnaryMinus}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+
+/**
+ * The builder to generate V2 expressions from catalyst expressions.
+ */
+class V2ExpressionBuilder(e: Expression) {
+
+ def build(): Option[V2Expression] = generateExpression(e)
+
+ private def canTranslate(b: BinaryOperator) = b match {
+ case _: And | _: Or => true
+ case _: BinaryComparison => true
+ case _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true
+ case add: Add => add.failOnError
+ case sub: Subtract => sub.failOnError
+ case mul: Multiply => mul.failOnError
+ case div: Divide => div.failOnError
+ case r: Remainder => r.failOnError
+ case _ => false
+ }
+
+ private def generateExpression(expr: Expression): Option[V2Expression] = expr match {
+ case Literal(value, dataType) => Some(LiteralValue(value, dataType))
+ case attr: Attribute => Some(FieldReference.column(attr.name))
+ case IsNull(col) => generateExpression(col)
+ .map(c => new GeneralScalarExpression("IS_NULL", Array[V2Expression](c)))
+ case IsNotNull(col) => generateExpression(col)
+ .map(c => new GeneralScalarExpression("IS_NOT_NULL", Array[V2Expression](c)))
+ case b: BinaryOperator if canTranslate(b) =>
+ val left = generateExpression(b.left)
+ val right = generateExpression(b.right)
+ if (left.isDefined && right.isDefined) {
+ Some(new GeneralScalarExpression(b.sqlOperator, Array[V2Expression](left.get, right.get)))
+ } else {
+ None
+ }
+ case Not(eq: EqualTo) =>
+ val left = generateExpression(eq.left)
+ val right = generateExpression(eq.right)
+ if (left.isDefined && right.isDefined) {
+ Some(new GeneralScalarExpression("!=", Array[V2Expression](left.get, right.get)))
+ } else {
+ None
+ }
+ case Not(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("NOT", Array[V2Expression](v)))
+ case UnaryMinus(child, true) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("-", Array[V2Expression](v)))
+ case BitwiseNot(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("~", Array[V2Expression](v)))
+ case CaseWhen(branches, elseValue) =>
+ val conditions = branches.map(_._1).flatMap(generateExpression)
+ val values = branches.map(_._2).flatMap(generateExpression)
+ if (conditions.length == branches.length && values.length == branches.length) {
+ val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
+ Seq[V2Expression](c, v)
+ }
+ if (elseValue.isDefined) {
+ elseValue.flatMap(generateExpression).map { v =>
+ val children = (branchExpressions :+ v).toArray[V2Expression]
+            // The children look like [condition1, value1, ..., conditionN, valueN, elseValue]
+ new GeneralScalarExpression("CASE_WHEN", children)
+ }
+ } else {
+          // The children look like [condition1, value1, ..., conditionN, valueN]
+ Some(new GeneralScalarExpression("CASE_WHEN", branchExpressions.toArray[V2Expression]))
+ }
+ } else {
+ None
+ }
+ // TODO supports other expressions
+ case _ => None
+ }
+}
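
A minimal Scala sketch (not part of this diff) of the catalyst-to-V2 translation path, assuming a
bare AttributeReference as input:

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, IsNull}
  import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
  import org.apache.spark.sql.types.StringType

  val attr = AttributeReference("name", StringType)()

  // Roughly Some(GeneralScalarExpression("IS_NULL", Array(FieldReference("name"))));
  // None means the catalyst expression cannot be translated and is therefore not pushed down.
  val v2 = new V2ExpressionBuilder(IsNull(attr)).build()
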
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a1602a3aa4880..c386655c947f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -38,10 +38,10 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
-import org.apache.spark.sql.catalyst.util.ExpressionSQLBuilder
+import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
-import org.apache.spark.sql.connector.expressions.{Expression => ExpressionV2, FieldReference, GeneralSQLExpression, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
@@ -776,8 +776,8 @@ object DataSourceStrategy
Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray))
}
- protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
- def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
+ protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = {
+ def translateOortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
val directionV2 = directionV1 match {
case Ascending => SortDirection.ASCENDING
@@ -864,8 +864,8 @@ object PushableColumnWithoutNestedColumn extends PushableColumnBase {
* Get the expression of DS V2 to represent catalyst expression that can be pushed down.
*/
object PushableExpression {
- def unapply(e: Expression): Option[ExpressionV2] = e match {
+ def unapply(e: Expression): Option[V2Expression] = e match {
case PushableColumnWithoutNestedColumn(name) => Some(FieldReference.column(name))
- case _ => new ExpressionSQLBuilder(e).build().map(new GeneralSQLExpression(_))
+ case _ => new V2ExpressionBuilder(e).build()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 2d10bbf5de537..a7e0ec8b72a7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -32,8 +32,9 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, Timesta
import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.index.TableIndex
-import org.apache.spark.sql.connector.expressions.{FieldReference, GeneralSQLExpression, NamedReference}
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -194,6 +195,31 @@ abstract class JdbcDialect extends Serializable with Logging{
case _ => value
}
+ class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
+ override def visitFieldReference(fieldRef: FieldReference): String = {
+ if (fieldRef.fieldNames().length != 1) {
+ throw new IllegalArgumentException(
+        "FieldReference with a multi-part or empty field name is unsupported: " + fieldRef)
+ }
+ quoteIdentifier(fieldRef.fieldNames.head)
+ }
+ }
+
+ /**
+   * Converts a V2 expression into a String representing a SQL expression.
+   * @param expr The V2 expression to be converted.
+   * @return The SQL string, or None if the expression cannot be converted.
+ */
+ @Since("3.3.0")
+ def compileExpression(expr: Expression): Option[String] = {
+ val jdbcSQLBuilder = new JDBCSQLBuilder()
+ try {
+ Some(jdbcSQLBuilder.build(expr))
+ } catch {
+ case _: IllegalArgumentException => None
+ }
+ }
+
/**
* Converts aggregate function to String representing a SQL expression.
* @param aggFunction The aggregate function to be converted.
@@ -203,55 +229,20 @@ abstract class JdbcDialect extends Serializable with Logging{
def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
aggFunction match {
case min: Min =>
- val sql = min.column match {
- case field: FieldReference =>
- if (field.fieldNames.length != 1) return None
- quoteIdentifier(field.fieldNames.head)
- case expr: GeneralSQLExpression =>
- expr.sql()
- }
- Some(s"MIN($sql)")
+ compileExpression(min.column).map(v => s"MIN($v)")
case max: Max =>
- val sql = max.column match {
- case field: FieldReference =>
- if (field.fieldNames.length != 1) return None
- quoteIdentifier(field.fieldNames.head)
- case expr: GeneralSQLExpression =>
- expr.sql()
- }
- Some(s"MAX($sql)")
+ compileExpression(max.column).map(v => s"MAX($v)")
case count: Count =>
- val sql = count.column match {
- case field: FieldReference =>
- if (field.fieldNames.length != 1) return None
- quoteIdentifier(field.fieldNames.head)
- case expr: GeneralSQLExpression =>
- expr.sql()
- }
val distinct = if (count.isDistinct) "DISTINCT " else ""
- Some(s"COUNT($distinct$sql)")
+ compileExpression(count.column).map(v => s"COUNT($distinct$v)")
case sum: Sum =>
- val sql = sum.column match {
- case field: FieldReference =>
- if (field.fieldNames.length != 1) return None
- quoteIdentifier(field.fieldNames.head)
- case expr: GeneralSQLExpression =>
- expr.sql()
- }
val distinct = if (sum.isDistinct) "DISTINCT " else ""
- Some(s"SUM($distinct$sql)")
+ compileExpression(sum.column).map(v => s"SUM($distinct$v)")
case _: CountStar =>
Some("COUNT(*)")
case avg: Avg =>
- val sql = avg.column match {
- case field: FieldReference =>
- if (field.fieldNames.length != 1) return None
- quoteIdentifier(field.fieldNames.head)
- case expr: GeneralSQLExpression =>
- expr.sql()
- }
val distinct = if (avg.isDistinct) "DISTINCT " else ""
- Some(s"AVG($distinct$sql)")
+ compileExpression(avg.column).map(v => s"AVG($distinct$v)")
case _ => None
}
}
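
A minimal Scala sketch (not part of this diff) of the new compileExpression entry point; the JDBC
URL and column name are placeholders:

  import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression, LiteralValue}
  import org.apache.spark.sql.jdbc.JdbcDialects
  import org.apache.spark.sql.types.IntegerType

  val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")
  val raise = new GeneralScalarExpression("+",
    Array[Expression](FieldReference.column("salary"), LiteralValue(1000, IntegerType)))

  // Expected to yield something like Some("(\"salary\") + (1000)") with the default identifier
  // quoting; None is returned when the builder cannot translate a node, so the push-down is skipped.
  println(dialect.compileExpression(raise))
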
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
index 3edc4b9502064..fb1d049e8e2d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
@@ -97,7 +95,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable
name: String,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): InMemoryTable = {
+ properties: java.util.Map[String, String]): InMemoryTable = {
new InMemoryTable(name, schema, partitions, properties)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index a1463523d38ff..92a5c552108b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
import java.util.Collections
import test.org.apache.spark.sql.connector.catalog.functions._
@@ -37,7 +36,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
- private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
+ private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String]
private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = {
catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index fd3c69eff5652..00d2a445ab47e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.connector
import java.io.File
-import java.util
import java.util.OptionalLong
import test.org.apache.spark.sql.connector._
@@ -552,7 +551,7 @@ abstract class SimpleBatchTable extends Table with SupportsRead {
override def name(): String = this.getClass.toString
- override def capabilities(): util.Set[TableCapability] = util.EnumSet.of(BATCH_READ)
+ override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ)
}
abstract class SimpleScanBuilder extends ScanBuilder
@@ -575,7 +574,7 @@ trait TestingV2Source extends TableProvider {
override def getTable(
schema: StructType,
partitioning: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
getTable(new CaseInsensitiveStringMap(properties))
}
@@ -792,7 +791,7 @@ class SchemaRequiredDataSource extends TableProvider {
override def getTable(
schema: StructType,
partitioning: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
val userGivenSchema = schema
new SimpleBatchTable {
override def schema(): StructType = userGivenSchema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala
index 094667001b6c3..e3d61a846fdb4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability}
@@ -61,7 +59,7 @@ class TestLocalScanCatalog extends BasicInMemoryTableCatalog {
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
val table = new TestLocalScanTable(ident.toString)
tables.put(ident, table)
table
@@ -76,8 +74,8 @@ object TestLocalScanTable {
class TestLocalScanTable(override val name: String) extends Table with SupportsRead {
override def schema(): StructType = TestLocalScanTable.schema
- override def capabilities(): util.Set[TableCapability] =
- util.EnumSet.of(TableCapability.BATCH_READ)
+ override def capabilities(): java.util.Set[TableCapability] =
+ java.util.EnumSet.of(TableCapability.BATCH_READ)
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
new TestLocalScanBuilder
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
index 99c322a7155f2..64c893ed74fdb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.connector
import java.io.{BufferedReader, InputStreamReader, IOException}
-import java.util
import scala.collection.JavaConverters._
@@ -138,8 +137,8 @@ class SimpleWritableDataSource extends TestingV2Source {
new MyWriteBuilder(path, info)
}
- override def capabilities(): util.Set[TableCapability] =
- util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE)
+ override def capabilities(): java.util.Set[TableCapability] =
+ java.util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE)
}
override def getTable(options: CaseInsensitiveStringMap): Table = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
index a12065ec0ab2a..5f2e0b28aeccc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
@@ -215,8 +213,8 @@ private case object TestRelation extends LeafNode with NamedRelation {
private case class CapabilityTable(_capabilities: TableCapability*) extends Table {
override def name(): String = "capability_test_table"
override def schema(): StructType = TableCapabilityCheckSuite.schema
- override def capabilities(): util.Set[TableCapability] = {
- val set = util.EnumSet.noneOf(classOf[TableCapability])
+ override def capabilities(): java.util.Set[TableCapability] = {
+ val set = java.util.EnumSet.noneOf(classOf[TableCapability])
_capabilities.foreach(set.add)
set
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
index bf2749d1afc53..0a0aaa8021996 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
@@ -35,7 +34,7 @@ import org.apache.spark.sql.types.StructType
*/
private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension {
- protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
+ protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
private val tableCreated: AtomicBoolean = new AtomicBoolean(false)
@@ -48,7 +47,7 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating
name: String,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): T
+ properties: java.util.Map[String, String]): T
override def loadTable(ident: Identifier): Table = {
if (tables.containsKey(ident)) {
@@ -69,12 +68,12 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
val key = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY
val propsWithLocation = if (properties.containsKey(key)) {
// Always set a location so that CREATE EXTERNAL TABLE won't fail with LOCATION not specified.
if (!properties.containsKey(TableCatalog.PROP_LOCATION)) {
- val newProps = new util.HashMap[String, String]()
+ val newProps = new java.util.HashMap[String, String]()
newProps.putAll(properties)
newProps.put(TableCatalog.PROP_LOCATION, "file:/abc")
newProps
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
index ff1bd29808637..c5be222645b19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability}
@@ -104,7 +102,7 @@ class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog {
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
// To simplify the test implementation, only support fixed schema.
if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) {
throw new UnsupportedOperationException
@@ -129,8 +127,8 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp
override def schema(): StructType = V1ReadFallbackCatalog.schema
- override def capabilities(): util.Set[TableCapability] = {
- util.EnumSet.of(TableCapability.BATCH_READ)
+ override def capabilities(): java.util.Set[TableCapability] = {
+ java.util.EnumSet.of(TableCapability.BATCH_READ)
}
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 9fbaf7890f8f8..992c46cc6cdb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -223,7 +221,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV
name: String,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): InMemoryTableWithV1Fallback = {
+ properties: java.util.Map[String, String]): InMemoryTableWithV1Fallback = {
val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties)
InMemoryV1Provider.tables.put(name, t)
tables.put(Identifier.of(Array("default"), name), t)
@@ -321,7 +319,7 @@ class InMemoryTableWithV1Fallback(
override val name: String,
override val schema: StructType,
override val partitioning: Array[Transform],
- override val properties: util.Map[String, String])
+ override val properties: java.util.Map[String, String])
extends Table
with SupportsWrite with SupportsRead {
@@ -331,7 +329,7 @@ class InMemoryTableWithV1Fallback(
}
}
- override def capabilities: util.Set[TableCapability] = util.EnumSet.of(
+ override def capabilities: java.util.Set[TableCapability] = java.util.EnumSet.of(
TableCapability.BATCH_READ,
TableCapability.V1_BATCH_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index aa0289ae75bdb..3f90fb47efb28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.jdbc
import java.sql.{Connection, DriverManager}
import java.util.Properties
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
@@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering,
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.functions.{avg, count, lit, sum, udf}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@@ -841,6 +842,34 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
Row(2, 2, 2, 2, 2, 0d, 12000d, 0d, 12000d, 12000d, 0d, 0d, 3, 0d)))
}
+ test("scan with aggregate push-down: aggregate function with binary arithmetic") {
+ Seq(false, true).foreach { ansiMode =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+ val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee")
+ checkAggregateRemoved(df, ansiMode)
+ val expected_plan_fragment = if (ansiMode) {
+ "PushedAggregates: [SUM((2147483647) + (DEPT))], " +
+ "PushedFilters: [], PushedGroupByColumns: []"
+ } else {
+ "PushedFilters: []"
+ }
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ if (ansiMode) {
+ val e = intercept[SparkException] {
+ checkAnswer(df, Seq(Row(-10737418233L)))
+ }
+ assert(e.getMessage.contains(
+ "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\""))
+ } else {
+ checkAnswer(df, Seq(Row(-10737418233L)))
+ }
+ }
+ }
+ }
+
test("scan with aggregate push-down: aggregate function with UDF") {
val df = spark.table("h2.test.employee")
val decrease = udf { (x: Double, y: Double) => x - y }