diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index c8d924db75aed..a7d1ed7f85e84 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -93,6 +93,10 @@ public String build(Expression expr) {
         return visitNot(build(e.children()[0]));
       case "~":
         return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
+      case "ABS":
+      case "COALESCE":
+        return visitSQLFunction(name,
+          Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
       case "CASE_WHEN": {
         List<String> children =
           Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
@@ -210,6 +214,10 @@ protected String visitCaseWhen(String[] children) {
     return sb.toString();
   }
 
+  protected String visitSQLFunction(String funcName, String[] inputs) {
+    return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+  }
+
   protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
     throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 5fd01ac5636b1..37db499470aa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -95,6 +95,15 @@ class V2ExpressionBuilder(
       }
     case Cast(child, dataType, _, true) =>
       generateExpression(child).map(v => new V2Cast(v, dataType))
+    case Abs(child, true) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v)))
+    case Coalesce(children) =>
+      val childrenExpressions = children.flatMap(generateExpression(_))
+      if (children.length == childrenExpressions.length) {
+        Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
     case and: And =>
       // AND expects predicate
       val l = generateExpression(and.left, true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 67a02904660c3..4f7974999ba86 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, not, sum, udf, when}
+import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils
@@ -381,19 +381,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2)))
 
         val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1)
-
         checkFiltersRemoved(df2, ansiMode)
-
-        df2.queryExecution.optimizedPlan.collect {
-          case _: DataSourceV2ScanRelation =>
-            val expected_plan_fragment = if (ansiMode) {
-              "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
-            } else {
-              "PushedFilters: [ID IS NOT NULL], "
-            }
-            checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+        val expectedPlanFragment2 = if (ansiMode) {
+          "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
+        } else {
+          "PushedFilters: [ID IS NOT NULL], "
         }
-
+        checkPushedInfo(df2, expectedPlanFragment2)
         if (ansiMode) {
           val e = intercept[SparkException] {
             checkAnswer(df2, Seq.empty)
@@ -422,22 +416,30 @@
 
         val df4 = spark.table("h2.test.employee")
           .filter(($"salary" > 1000d).and($"salary" < 12000d))
-
         checkFiltersRemoved(df4, ansiMode)
-
-        df4.queryExecution.optimizedPlan.collect {
-          case _: DataSourceV2ScanRelation =>
-            val expected_plan_fragment = if (ansiMode) {
-              "PushedFilters: [SALARY IS NOT NULL, " +
-                "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], "
-            } else {
-              "PushedFilters: [SALARY IS NOT NULL], "
-            }
-            checkKeywordsExistsInExplain(df4, expected_plan_fragment)
+        val expectedPlanFragment4 = if (ansiMode) {
+          "PushedFilters: [SALARY IS NOT NULL, " +
+            "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], "
+        } else {
+          "PushedFilters: [SALARY IS NOT NULL], "
         }
-
+        checkPushedInfo(df4, expectedPlanFragment4)
         checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true),
           Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))
+
+        val df5 = spark.table("h2.test.employee")
+          .filter(abs($"dept" - 3) > 1)
+          .filter(coalesce($"salary", $"bonus") > 2000)
+        checkFiltersRemoved(df5, ansiMode)
+        val expectedPlanFragment5 = if (ansiMode) {
+          "PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1, " +
+            "(COALESCE(CAST(SALARY AS double), BONUS)) > 2000.0]"
+        } else {
+          "PushedFilters: [DEPT IS NOT NULL]"
+        }
+        checkPushedInfo(df5, expectedPlanFragment5)
+        checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
+          Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true)))
       }
     }
   }