diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 91dae749f974b..1df01d29cbdd1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -75,7 +75,7 @@ public String build(Expression expr) {
             name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
         case "-":
           if (e.children().length == 1) {
-            return visitUnaryArithmetic(name, build(e.children()[0]));
+            return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
           } else {
             return visitBinaryArithmetic(
               name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
@@ -87,7 +87,7 @@ public String build(Expression expr) {
         case "NOT":
           return visitNot(build(e.children()[0]));
         case "~":
-          return visitUnaryArithmetic(name, build(e.children()[0]));
+          return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
         case "CASE_WHEN": {
           List<String> children =
             Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
@@ -179,7 +179,7 @@ protected String visitNot(String v) {
     return "NOT (" + v + ")";
   }
 
-  protected String visitUnaryArithmetic(String name, String v) { return name +" (" + v + ")"; }
+  protected String visitUnaryArithmetic(String name, String v) { return name + v; }
 
   protected String visitCaseWhen(String[] children) {
     StringBuilder sb = new StringBuilder("CASE");
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 5c8e6a67ce3f0..fbd6884358b0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -94,6 +94,7 @@ class V2ExpressionBuilder(
         None
       }
     case and: And =>
+      // AND expects predicate
      val l = generateExpression(and.left, true)
      val r = generateExpression(and.right, true)
      if (l.isDefined && r.isDefined) {
@@ -103,6 +104,7 @@ class V2ExpressionBuilder(
        None
      }
    case or: Or =>
+      // OR expects predicate
      val l = generateExpression(or.left, true)
      val r = generateExpression(or.right, true)
      if (l.isDefined && r.isDefined) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 4b28de26b59e4..674ef005df2dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -229,9 +229,8 @@ abstract class JdbcDialect extends Serializable with Logging{
     override def visitNamedReference(namedRef: NamedReference): String = {
       if (namedRef.fieldNames().length > 1) {
-        throw new IllegalArgumentException(
-          QueryCompilationErrors.commandNotSupportNestedColumnError(
-            "Filter push down", namedRef.toString).getMessage);
+        throw QueryCompilationErrors.commandNotSupportNestedColumnError(
+          "Filter push down", namedRef.toString)
       }
       quoteIdentifier(namedRef.fieldNames.head)
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index d50a0551226a9..d6f098f1d5189 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -402,14 +402,38 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2)))
 
-    val df2 = sql("""
+    val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1)
+
+    checkFiltersRemoved(df2, ansiMode)
+
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment = if (ansiMode) {
+          "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
+        } else {
+          "PushedFilters: [ID IS NOT NULL], "
+        }
+        checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+    }
+
+    if (ansiMode) {
+      val e = intercept[SparkException] {
+        checkAnswer(df2, Seq.empty)
+      }
+      assert(e.getMessage.contains(
+        "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\""))
+    } else {
+      checkAnswer(df2, Seq.empty)
+    }
+
+    val df3 = sql("""
         |SELECT * FROM h2.test.employee
         |WHERE (CASE WHEN SALARY > 10000 THEN BONUS ELSE BONUS + 200 END) > 1200
         |""".stripMargin)
 
-    checkFiltersRemoved(df2, ansiMode)
+    checkFiltersRemoved(df3, ansiMode)
 
-    df2.queryExecution.optimizedPlan.collect {
+    df3.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
         val expected_plan_fragment = if (ansiMode) {
           "PushedFilters: [(CASE WHEN SALARY > 10000.00 THEN BONUS" +
@@ -417,10 +441,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         } else {
           "PushedFilters: []"
         }
-        checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+        checkKeywordsExistsInExplain(df3, expected_plan_fragment)
     }
 
-    checkAnswer(df2,
+    checkAnswer(df3,
       Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))
   }
 }
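
Note on the visitUnaryArithmetic change above: the old code built its own parentheses (name + " (" + v + ")", e.g. "- (col)"), while the new code returns name + v and expects the caller to pre-render the operand with inputToSQL, which parenthesizes only compound children. A minimal sketch of the inputToSQL behavior assumed here, paraphrased for illustration; this method body is not part of the patch:

    // Sketch only: wrap inputs that have multiple children so operator
    // precedence survives embedding. "-" over "a + b" renders as
    // "-(a + b)"; "-" over a bare column renders as "-col".
    protected String inputToSQL(Expression input) {
      if (input.children().length > 1) {
        // Compound input: parenthesize once before prepending the operator.
        return "(" + build(input) + ")";
      } else {
        // Simple input (column reference, literal): render bare.
        return build(input);
      }
    }

Under this assumption the operand of a unary operator is parenthesized at most once, which is consistent with the pushed-filter string asserted in the test above, "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1]", where the compound child ID + 2147483647 is wrapped exactly once.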