From 0a6035298981bd78785afce7e058e4a1eb512d6f Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 25 Aug 2020 19:22:50 -0700 Subject: [PATCH 01/17] SPARK-24994: Unwrap casts for integral types --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/optimizer/UnwrapCast.scala | 155 +++++++++++++++++ .../catalyst/optimizer/UnwrapCastSuite.scala | 109 ++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 157 ++++++++++++++++++ 4 files changed, 422 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 296fe86e834e5..5b2df94c22968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -107,6 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager) RewriteCorrelatedScalarSubquery, EliminateSerialization, RemoveRedundantAliases, + UnwrapCast, RemoveNoopOperators, CombineWithFields, SimplifyExtractValueOps, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala new file mode 100644 index 0000000000000..b3a94ac635cec --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types._ + +/** + * Unwrap casts in binary comparison operations with the following pattern: + * `BinaryComparison(Cast(fromExp, _), Literal(value, toType))` + * This rule optimize expressions with this pattern by either replacing the cast with simpler + * constructs, or moving the cast from the expression side to the literal side, so they can be + * optimized away and pushed down to data sources. + * + * Currently this only handles the case where `fromType` (of `fromExp`) and `toType` are of + * integral types (i.e., byte, short, int and long). It checks to see if the literal `value` is + * within range (min, max) of the `fromType`. If this is true then it means we can safely cast the + * `value` to the `fromType` and thus able to move the cast to the literal side. 
Otherwise, it + * replaces the cast with different simpler constructs, such as + * `EqualTo(fromExp, Literal(max, fromType)` when input is + * `GreaterThanOrEqualTo(Cast(fromExp, fromType), Literal(max, toType))` + */ +object UnwrapCast extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case l: LogicalPlan => l transformExpressionsUp { + case e @ BinaryComparison(_, _) => unwrapCast(e) + } + } + + private def unwrapCast(exp: Expression): Expression = exp match { + case BinaryComparison(Literal(_, _), Cast(_, _, _)) => + def swap(e: Expression): Expression = e match { + case GreaterThan(left, right) => LessThan(right, left) + case GreaterThanOrEqual(left, right) => LessThanOrEqual(right, left) + case EqualTo(left, right) => EqualTo(right, left) + case EqualNullSafe(left, right) => EqualNullSafe(right, left) + case LessThanOrEqual(left, right) => GreaterThanOrEqual(right, left) + case LessThan(left, right) => GreaterThan(right, left) + case other => other + } + swap(unwrapCast(swap(exp))) + + case BinaryComparison(Cast(fromExp, _, _), Literal(value, toType)) => + val fromType = fromExp.dataType + if (!fromType.isInstanceOf[IntegralType] || !toType.isInstanceOf[IntegralType] + || !Cast.canUpCast(fromType, toType)) { + return exp + } + + // Check if the literal value is within the range of the `fromType`, and handle the boundary + // cases in the following + val toIntegralType = toType.asInstanceOf[IntegralType] + val (min, max) = getRange(fromType) + val (minInToType, maxInToType) = + (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval()) + + // Compare upper bounds + val maxCmp = toIntegralType.ordering.asInstanceOf[Ordering[Any]].compare(value, maxInToType) + if (maxCmp > 0) { + exp match { + case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) => + return falseIfNotNull(fromExp) + case LessThan(_, _) | LessThanOrEqual(_, _) => + return trueIfNotNull(fromExp) + case EqualNullSafe(_, _) 
=> + return FalseLiteral + case _ => return exp // impossible but safe guard, same below + } + } else if (maxCmp == 0) { + exp match { + case GreaterThan(_, _) => + return falseIfNotNull(fromExp) + case LessThanOrEqual(_, _) => + return trueIfNotNull(fromExp) + case LessThan(_, _) => + return Not(EqualTo(fromExp, Literal(max, fromType))) + case GreaterThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _) => + return EqualTo(fromExp, Literal(max, fromType)) + case _ => return exp + } + } + + // Compare lower bounds + val minCmp = toIntegralType.ordering.asInstanceOf[Ordering[Any]].compare(value, minInToType) + if (minCmp < 0) { + exp match { + case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => + return trueIfNotNull(fromExp) + case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) => + return falseIfNotNull(fromExp) + case EqualNullSafe(_, _) => + return FalseLiteral + case _ => return exp + } + } else if (minCmp == 0) { + exp match { + case LessThan(_, _) => + return falseIfNotNull(fromExp) + case GreaterThanOrEqual(_, _) => + return trueIfNotNull(fromExp) + case GreaterThan(_, _) => + return Not(EqualTo(fromExp, Literal(min, fromType))) + case LessThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _) => + return EqualTo(fromExp, Literal(min, fromType)) + case _ => return exp + } + } + + // Now we can assume `value` is within the bound of the source type, e.g., min < value < max + val lit = Cast(Literal(value), fromType) + exp match { + case GreaterThan(_, _) => GreaterThan(fromExp, lit) + case GreaterThanOrEqual(_, _) => GreaterThanOrEqual(fromExp, lit) + case EqualTo(_, _) => EqualTo(fromExp, lit) + case EqualNullSafe(_, _) => EqualNullSafe(fromExp, lit) + case LessThan(_, _) => LessThan(fromExp, lit) + case LessThanOrEqual(_, _) => LessThanOrEqual(fromExp, lit) + } + + case _ => exp + } + + private[sql] def falseIfNotNull(e: Expression): Expression = + And(IsNull(e), Literal(null, BooleanType)) + + private[sql] def trueIfNotNull(e: Expression): 
Expression = + Or(IsNotNull(e), Literal(null, BooleanType)) + + private def getRange(ty: DataType): (Any, Any) = ty match { + case ByteType => (Byte.MinValue, Byte.MaxValue) + case ShortType => (Short.MinValue, Short.MaxValue) + case IntegerType => (Int.MinValue, Int.MaxValue) + case LongType => (Long.MinValue, Long.MaxValue) + } +} + + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala new file mode 100644 index 0000000000000..17667367e0d46 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class UnwrapCastSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches: List[Batch] = + Batch("Unwrap casts", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyCasts, + UnwrapCast) :: Nil + } + + val testRelation: LocalRelation = LocalRelation('a.short) + + test("unwrap casts when literal == max") { + val v = Short.MaxValue + assertEquivalent('a > v.toInt, UnwrapCast.falseIfNotNull('a)) + assertEquivalent('a >= v.toInt, 'a === v) + assertEquivalent('a === v.toInt, 'a === v) + assertEquivalent('a <=> v.toInt, 'a === v) + assertEquivalent('a <= v.toInt, UnwrapCast.trueIfNotNull('a)) + assertEquivalent('a < v.toInt, 'a =!= v) + } + + test("unwrap casts when literal > max") { + val v: Int = Short.MaxValue + 100 + assertEquivalent('a > v, UnwrapCast.falseIfNotNull('a)) + assertEquivalent('a >= v, UnwrapCast.falseIfNotNull('a)) + assertEquivalent('a === v, UnwrapCast.falseIfNotNull('a)) + assertEquivalent('a <=> v, false) + assertEquivalent('a <= v, UnwrapCast.trueIfNotNull('a)) + assertEquivalent('a < v, UnwrapCast.trueIfNotNull('a)) + } + + test("unwrap casts when literal == min") { + val v = Short.MinValue + assertEquivalent('a > v.toInt, 'a =!= v) + assertEquivalent('a >= v.toInt, UnwrapCast.trueIfNotNull('a)) + assertEquivalent('a === v.toInt, 'a === v) + assertEquivalent('a <=> v.toInt, 'a === v) + assertEquivalent('a <= v.toInt, 'a === v) + assertEquivalent('a < v.toInt, UnwrapCast.falseIfNotNull('a)) + } + + test("unwrap casts when literal < min") { + val v: Int = 
Short.MinValue - 100 + assertEquivalent('a > v, UnwrapCast.trueIfNotNull('a)) + assertEquivalent('a >= v, UnwrapCast.trueIfNotNull('a)) + assertEquivalent('a === v, UnwrapCast.falseIfNotNull('a)) + assertEquivalent('a <=> v, false) + assertEquivalent('a <= v, UnwrapCast.falseIfNotNull('a)) + assertEquivalent('a < v, UnwrapCast.falseIfNotNull('a)) + } + + test("unwrap casts when literal is within range (min, max)") { + assertEquivalent('a > 300, 'a > 300.toShort) + assertEquivalent('a >= 500, 'a >= 500.toShort) + assertEquivalent('a === 32766, 'a === 32766.toShort) + assertEquivalent('a <=> 32766, 'a <=> 32766.toShort) + assertEquivalent('a <= -6000, 'a <= -6000.toShort) + assertEquivalent('a < -32767, 'a < -32767.toShort) + } + + test("unwrap casts when cast is on rhs") { + val v = Short.MaxValue + assertEquivalent(Literal(v.toInt) < 'a, UnwrapCast.falseIfNotNull('a)) + assertEquivalent(Literal(v.toInt) <= 'a, Literal(v) === 'a) + assertEquivalent(Literal(v.toInt) === 'a, Literal(v) === 'a) + assertEquivalent(Literal(v.toInt) <=> 'a, Literal(v) === 'a) + assertEquivalent(Literal(v.toInt) >= 'a, UnwrapCast.trueIfNotNull('a)) + assertEquivalent(Literal(v.toInt) > 'a, 'a =!= v) + + assertEquivalent(Literal(30) <= 'a, Literal(30.toShort) <= 'a) + } + + private def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val plan = testRelation.where(e1).analyze + val actual = Optimize.execute(plan) + val expected = testRelation.where(e2).analyze + comparePlans(actual, expected) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b86df4db816b3..739131f3bcc32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3691,6 +3691,163 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(sql("SELECT id FROM t WHERE (SELECT 
true)"), Row(0L)) } } + + test("test casts pushdown on orc/parquet for integral types") { + def checkPushedFilters( + format: String, + df: DataFrame, + filters: Array[sources.Filter], + noScan: Boolean = false): Unit = { + val scanExec = df.queryExecution.sparkPlan + .find(_.isInstanceOf[BatchScanExec]) + if (noScan) { + assert(scanExec.isEmpty) + return + } + val scan = scanExec.get.asInstanceOf[BatchScanExec].scan + format match { + case "orc" => + assert(scan.isInstanceOf[OrcScan]) + assert(scan.asInstanceOf[OrcScan].pushedFilters === filters) + case "parquet" => + assert(scan.isInstanceOf[ParquetScan]) + assert(scan.asInstanceOf[ParquetScan].pushedFilters === filters) + case _ => + fail(s"unknown format $format") + } + } + + Seq("orc", "parquet").foreach { format => + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + withTempPath { dir => + spark.range(100).map(i => (i.toShort, i.toString)).toDF("id", "s") + .write + .format(format) + .save(dir.getCanonicalPath) + val df = spark.read.format(format).load(dir.getCanonicalPath) + + // cases when value == MAX + var v = Short.MaxValue + checkPushedFilters( + format, + df.where('id > v.toInt), + Array(), + noScan = true) + checkPushedFilters( + format, + df.where('id >= v.toInt), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters( + format, + df.where('id === v.toInt), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters( + format, + df.where('id <=> v.toInt), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters( + format, + df.where('id <= v.toInt), + Array(sources.IsNotNull("id"))) + checkPushedFilters( + format, + df.where('id < v.toInt), + Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) + + // cases when value > MAX + var v1: Int = Short.MaxValue.toInt + 100 + checkPushedFilters( + format, + df.where('id > v1), + Array(), + noScan = true) + checkPushedFilters( + format, + df.where('id >= v1), + 
Array(), + noScan = true) + checkPushedFilters( + format, + df.where('id === v1), + Array(), + noScan = true) + checkPushedFilters( + format, + df.where('id <=> v1), + Array(), + noScan = true) + checkPushedFilters( + format, + df.where('id <= v1), + Array(sources.IsNotNull("id"))) + checkPushedFilters( + format, + df.where('id < v1), + Array(sources.IsNotNull("id"))) + + // cases when value = MIN + v = Short.MinValue + checkPushedFilters( + format, + df.where(lit(v.toInt) < 'id), + Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) + checkPushedFilters( + format, + df.where(lit(v.toInt) <= 'id), + Array(sources.IsNotNull("id"))) + checkPushedFilters( + format, + df.where(lit(v.toInt) === 'id), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters( + format, + df.where(lit(v.toInt) >= 'id), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters( + format, + df.where(lit(v.toInt) > 'id), + Array(), + noScan = true) + + // cases when value < MIN + v1 = Short.MinValue.toInt - 100 + checkPushedFilters( + format, + df.where(lit(v1) < 'id), + Array(sources.IsNotNull("id"))) + checkPushedFilters( + format, + df.where(lit(v1) <= 'id), + Array(sources.IsNotNull("id"))) + checkPushedFilters( + format, + df.where(lit(v1) === 'id), + Array(), + noScan = true) + checkPushedFilters( + format, + df.where(lit(v1) >= 'id), + Array(), + noScan = true) + checkPushedFilters( + format, + df.where(lit(v1) > 'id), + Array(), + noScan = true) + + // cases when value is within range (MIN, MAX) + checkPushedFilters( + format, + df.where('id > 30), + Array(sources.IsNotNull("id"), sources.GreaterThan("id", 30))) + checkPushedFilters( + format, + df.where(lit(100) >= 'id), + Array(sources.IsNotNull("id"), sources.LessThanOrEqual("id", 100))) + } + } + } + } } case class Foo(bar: Option[String]) From b7ee52bfa437927ce1fc658895ba934c433565a6 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Fri, 28 Aug 2020 00:50:59 -0700 
Subject: [PATCH 02/17] Address comments --- .../sql/catalyst/optimizer/UnwrapCast.scala | 215 +++++++++++------- .../catalyst/optimizer/UnwrapCastSuite.scala | 26 ++- 2 files changed, 155 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala index b3a94ac635cec..74c53db18efc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala @@ -24,19 +24,46 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ /** - * Unwrap casts in binary comparison operations with the following pattern: - * `BinaryComparison(Cast(fromExp, _), Literal(value, toType))` - * This rule optimize expressions with this pattern by either replacing the cast with simpler - * constructs, or moving the cast from the expression side to the literal side, so they can be - * optimized away and pushed down to data sources. + * Unwrap casts in binary comparison operations with patterns like following: * - * Currently this only handles the case where `fromType` (of `fromExp`) and `toType` are of - * integral types (i.e., byte, short, int and long). It checks to see if the literal `value` is - * within range (min, max) of the `fromType`. If this is true then it means we can safely cast the - * `value` to the `fromType` and thus able to move the cast to the literal side. 
Otherwise, it - * replaces the cast with different simpler constructs, such as - * `EqualTo(fromExp, Literal(max, fromType)` when input is - * `GreaterThanOrEqualTo(Cast(fromExp, fromType), Literal(max, toType))` + * `BinaryComparison(Cast(fromExp, toType), Literal(value, toType))` + * or + * `BinaryComparison(Literal(value, toType), Cast(fromExp, toType))` + * + * This rule optimizes expressions with the above pattern by either replacing the cast with simpler + * constructs, or moving the cast from the expression side to the literal side, which enables them + * to be optimized away later and pushed down to data sources. + * + * Currently this only handles cases where `fromType` (of `fromExp`) and `toType` are of integral + * types (i.e., byte, short, int and long). The rule checks to see if the literal `value` is + * within range `(min, max)`, where `min` and `max` are the minimum and maximum value of + * `fromType`, respectively. If this is true then it means we can safely cast `value` to `fromType` + * and are thus able to move the cast to the literal side. + * + * If the `value` is not within range `(min, max)`, the rule breaks the scenario into different + * cases and tries to replace each with simpler constructs. 
+ * + * if `value > max`, the cases are as follows: + * - `cast(exp, ty) > value` ==> if(isnull(exp), null, false) + * - `cast(exp, ty) >= value` ==> if(isnull(exp), null, false) + * - `cast(exp, ty) === value` ==> if(isnull(exp), null, false) + * - `cast(exp, ty) <=> value` ==> false + * - `cast(exp, ty) <= value` ==> if(isnull(exp), null, true) + * - `cast(exp, ty) < value` ==> if(isnull(exp), null, true) + * + * if `value == max`, the cases are as follows: + * - `cast(exp, ty) > value` ==> if(isnull(exp), null, false) + * - `cast(exp, ty) >= value` ==> exp == max + * - `cast(exp, ty) === value` ==> exp == max + * - `cast(exp, ty) <=> value` ==> exp == max + * - `cast(exp, ty) <= value` ==> if(isnull(exp), null, true) + * - `cast(exp, ty) < value` ==> exp =!= max + * + * Similarly for the cases when `value == min` and `value < min`. + * + * Further, the above `if(isnull(exp), null, false)` is represented using conjunction + * `and(isnull(exp), null)`, to enable further optimization and filter pushdown to data sources. + * Similarly, `if(isnull(exp), null, true)` is represented with `or(isnotnull(exp), null)`. */ object UnwrapCast extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -47,6 +74,9 @@ object UnwrapCast extends Rule[LogicalPlan] { private def unwrapCast(exp: Expression): Expression = exp match { case BinaryComparison(Literal(_, _), Cast(_, _, _)) => + // Not a canonical form. In this case we first canonicalize the expression by swapping the + // literal and cast side, then process the result and swap the literal and cast again to + // restore the original order. 
def swap(e: Expression): Expression = e match { case GreaterThan(left, right) => LessThan(right, left) case GreaterThanOrEqual(left, right) => LessThanOrEqual(right, left) @@ -54,77 +84,89 @@ object UnwrapCast extends Rule[LogicalPlan] { case EqualNullSafe(left, right) => EqualNullSafe(right, left) case LessThanOrEqual(left, right) => GreaterThanOrEqual(right, left) case LessThan(left, right) => GreaterThan(right, left) - case other => other + case _ => e } + swap(unwrapCast(swap(exp))) - case BinaryComparison(Cast(fromExp, _, _), Literal(value, toType)) => - val fromType = fromExp.dataType - if (!fromType.isInstanceOf[IntegralType] || !toType.isInstanceOf[IntegralType] - || !Cast.canUpCast(fromType, toType)) { - return exp - } + case BinaryComparison(Cast(fromExp, _, _), Literal(value, toType)) + if canImplicitlyCast(fromExp, toType) => - // Check if the literal value is within the range of the `fromType`, and handle the boundary - // cases in the following - val toIntegralType = toType.asInstanceOf[IntegralType] - val (min, max) = getRange(fromType) - val (minInToType, maxInToType) = - (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval()) - - // Compare upper bounds - val maxCmp = toIntegralType.ordering.asInstanceOf[Ordering[Any]].compare(value, maxInToType) - if (maxCmp > 0) { - exp match { - case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) => - return falseIfNotNull(fromExp) - case LessThan(_, _) | LessThanOrEqual(_, _) => - return trueIfNotNull(fromExp) - case EqualNullSafe(_, _) => - return FalseLiteral - case _ => return exp // impossible but safe guard, same below - } - } else if (maxCmp == 0) { - exp match { - case GreaterThan(_, _) => - return falseIfNotNull(fromExp) - case LessThanOrEqual(_, _) => - return trueIfNotNull(fromExp) - case LessThan(_, _) => - return Not(EqualTo(fromExp, Literal(max, fromType))) - case GreaterThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _) => - return EqualTo(fromExp, Literal(max, 
fromType)) - case _ => return exp - } - } + // In case both sides have integral type, optimize the comparison by removing casts or + // moving cast to the literal side. + simplifyIntegral(exp, fromExp, toType.asInstanceOf[IntegralType], value) - // Compare lower bounds - val minCmp = toIntegralType.ordering.asInstanceOf[Ordering[Any]].compare(value, minInToType) - if (minCmp < 0) { - exp match { - case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => - return trueIfNotNull(fromExp) - case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) => - return falseIfNotNull(fromExp) - case EqualNullSafe(_, _) => - return FalseLiteral - case _ => return exp - } - } else if (minCmp == 0) { - exp match { - case LessThan(_, _) => - return falseIfNotNull(fromExp) - case GreaterThanOrEqual(_, _) => - return trueIfNotNull(fromExp) - case GreaterThan(_, _) => - return Not(EqualTo(fromExp, Literal(min, fromType))) - case LessThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _) => - return EqualTo(fromExp, Literal(min, fromType)) - case _ => return exp - } - } + case _ => exp + } + + /** + * Check if the input `value` is within range `(min, max)` of the `fromType`, where `min` and + * `max` are the minimum and maximum value of the `fromType`. If the above is true, this + * optimizes the expression by moving the cast to the literal side. Otherwise if result is not + * true, this replaces the input binary comparison `exp` with simpler expressions. 
+ */ + private def simplifyIntegral( + exp: Expression, + fromExp: Expression, + toType: IntegralType, + value: Any): Expression = { - // Now we can assume `value` is within the bound of the source type, e.g., min < value < max + val fromType = fromExp.dataType + val (min, max) = getRange(fromType) + val (minInToType, maxInToType) = { + (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval()) + } + val ordering = toType.ordering.asInstanceOf[Ordering[Any]] + val minCmp = ordering.compare(value, minInToType) + val maxCmp = ordering.compare(value, maxInToType) + + if (maxCmp > 0) { + exp match { + case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) => + falseIfNotNull(fromExp) + case LessThan(_, _) | LessThanOrEqual(_, _) => + trueIfNotNull(fromExp) + case EqualNullSafe(_, _) => + FalseLiteral + case _ => exp // impossible but safe guard, same below + } + } else if (maxCmp == 0) { + exp match { + case GreaterThan(_, _) => + falseIfNotNull(fromExp) + case LessThanOrEqual(_, _) => + trueIfNotNull(fromExp) + case LessThan(_, _) => + Not(EqualTo(fromExp, Literal(max, fromType))) + case GreaterThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _) => + EqualTo(fromExp, Literal(max, fromType)) + case _ => exp + } + } else if (minCmp < 0) { + exp match { + case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => + trueIfNotNull(fromExp) + case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) => + falseIfNotNull(fromExp) + case EqualNullSafe(_, _) => + FalseLiteral + case _ => exp + } + } else if (minCmp == 0) { + exp match { + case LessThan(_, _) => + falseIfNotNull(fromExp) + case GreaterThanOrEqual(_, _) => + trueIfNotNull(fromExp) + case GreaterThan(_, _) => + Not(EqualTo(fromExp, Literal(min, fromType))) + case LessThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _) => + EqualTo(fromExp, Literal(min, fromType)) + case _ => exp + } + } else { + // This means `value` is within range `(min, max)`. 
Optimize this by moving the cast to the + // literal side. val lit = Cast(Literal(value), fromType) exp match { case GreaterThan(_, _) => GreaterThan(fromExp, lit) @@ -134,13 +176,30 @@ object UnwrapCast extends Rule[LogicalPlan] { case LessThan(_, _) => LessThan(fromExp, lit) case LessThanOrEqual(_, _) => LessThanOrEqual(fromExp, lit) } + } + } - case _ => exp + /** + * Check if the input `fromExp` can be safely cast to `toType` without any loss of precision, + * i.e., the conversion is injective. Note this only handles the case when both sides are of + * integral type. + */ + private def canImplicitlyCast(fromExp: Expression, toType: DataType): Boolean = { + fromExp.dataType.isInstanceOf[IntegralType] && toType.isInstanceOf[IntegralType] && + Cast.canUpCast(fromExp.dataType, toType) } + /** + * Wraps input expression `e` with `if(isnull(e), null, false)`. The if-clause is represented + * using `and(isnull(e), null)` which is semantically equivalent by applying 3-valued logic. + */ private[sql] def falseIfNotNull(e: Expression): Expression = And(IsNull(e), Literal(null, BooleanType)) + /** + * Wraps input expression `e` with `if(isnull(e), null, true)`. The if-clause is represented + * using `or(isnotnull(e), null)` which is semantically equivalent by applying 3-valued logic. 
+ */ private[sql] def trueIfNotNull(e: Expression): Expression = Or(IsNotNull(e), Literal(null, BooleanType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala index 17667367e0d46..707838342c4aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala @@ -23,21 +23,16 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.DoubleType class UnwrapCastSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches: List[Batch] = - Batch("Unwrap casts", FixedPoint(10), - NullPropagation, - ConstantFolding, - BooleanSimplification, - SimplifyConditionals, - SimplifyCasts, - UnwrapCast) :: Nil + Batch("Unwrap casts", FixedPoint(10), ConstantFolding, UnwrapCast) :: Nil } - val testRelation: LocalRelation = LocalRelation('a.short) + val testRelation: LocalRelation = LocalRelation('a.short, 'b.float) test("unwrap casts when literal == max") { val v = Short.MaxValue @@ -100,6 +95,21 @@ class UnwrapCastSuite extends PlanTest { assertEquivalent(Literal(30) <= 'a, Literal(30.toShort) <= 'a) } + test("unwrap cast should have no effect when input is not integral type") { + assertEquivalent('b > 42.0, Cast('b, DoubleType) > 42.0) + assertEquivalent('b >= 42.0, Cast('b, DoubleType) >= 42.0) + assertEquivalent('b === 42.0, Cast('b, DoubleType) === 42.0) + assertEquivalent('b <=> 42.0, Cast('b, DoubleType) <=> 42.0) + assertEquivalent('b <= 42.0, Cast('b, DoubleType) <= 42.0) + assertEquivalent('b < 42.0, Cast('b, DoubleType) < 42.0) + assertEquivalent(Literal(42.0) > 'b, Literal(42.0) 
> Cast('b, DoubleType)) + assertEquivalent(Literal(42.0) >= 'b, Literal(42.0) >= Cast('b, DoubleType)) + assertEquivalent(Literal(42.0) === 'b, Literal(42.0) === Cast('b, DoubleType)) + assertEquivalent(Literal(42.0) <=> 'b, Literal(42.0) <=> Cast('b, DoubleType)) + assertEquivalent(Literal(42.0) <= 'b, Literal(42.0) <= Cast('b, DoubleType)) + assertEquivalent(Literal(42.0) < 'b, Literal(42.0) < Cast('b, DoubleType)) + } + private def assertEquivalent(e1: Expression, e2: Expression): Unit = { val plan = testRelation.where(e1).analyze val actual = Optimize.execute(plan) From ed37d4919cb2f3eae38dc858d10806b7e1e52370 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Sat, 29 Aug 2020 22:35:45 -0700 Subject: [PATCH 03/17] Address comments --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- ...ala => UnwrapCastInBinaryComparison.scala} | 68 ++++----- ...> UnwrapCastInBinaryComparisonSuite.scala} | 56 +++++--- .../spark/sql/FileBasedDataSourceSuite.scala | 88 +++++++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 131 ------------------ 5 files changed, 157 insertions(+), 188 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/{UnwrapCast.scala => UnwrapCastInBinaryComparison.scala} (85%) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{UnwrapCastSuite.scala => UnwrapCastInBinaryComparisonSuite.scala} (68%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5b2df94c22968..9216ab1631e7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -107,7 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager) RewriteCorrelatedScalarSubquery, EliminateSerialization, RemoveRedundantAliases, - UnwrapCast, + UnwrapCastInBinaryComparison, 
RemoveNoopOperators, CombineWithFields, SimplifyExtractValueOps, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala similarity index 85% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index 74c53db18efc0..331bb9215a099 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -65,11 +65,12 @@ import org.apache.spark.sql.types._ * `and(isnull(exp), null)`, to enable further optimization and filter pushdown to data sources. * Similarly, `if(isnull(exp), null, true)` is represented with `or(isnotnull(exp), null)`. */ -object UnwrapCast extends Rule[LogicalPlan] { +object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case l: LogicalPlan => l transformExpressionsUp { - case e @ BinaryComparison(_, _) => unwrapCast(e) - } + case l: LogicalPlan => + l transformExpressionsUp { + case e @ BinaryComparison(_, _) => unwrapCast(e) + } } private def unwrapCast(exp: Expression): Expression = exp match { @@ -89,12 +90,11 @@ object UnwrapCast extends Rule[LogicalPlan] { swap(unwrapCast(swap(exp))) - case BinaryComparison(Cast(fromExp, _, _), Literal(value, toType)) - if canImplicitlyCast(fromExp, toType) => - + case BinaryComparison(Cast(fromExp, _, _), Literal(value, toType: IntegralType)) + if canImplicitlyCast(fromExp, toType) => // In case both sides have integral type, optimize the comparison by removing casts or // moving cast to the literal side. 
- simplifyIntegral(exp, fromExp, toType.asInstanceOf[IntegralType], value) + simplifyIntegral(exp, fromExp, toType, value) case _ => exp } @@ -123,9 +123,10 @@ object UnwrapCast extends Rule[LogicalPlan] { if (maxCmp > 0) { exp match { case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) => - falseIfNotNull(fromExp) + fromExp.falseIfNotNull + // falseIfNotNull(fromExp) case LessThan(_, _) | LessThanOrEqual(_, _) => - trueIfNotNull(fromExp) + fromExp.trueIfNotNull case EqualNullSafe(_, _) => FalseLiteral case _ => exp // impossible but safe guard, same below @@ -133,9 +134,9 @@ object UnwrapCast extends Rule[LogicalPlan] { } else if (maxCmp == 0) { exp match { case GreaterThan(_, _) => - falseIfNotNull(fromExp) + fromExp.falseIfNotNull case LessThanOrEqual(_, _) => - trueIfNotNull(fromExp) + fromExp.trueIfNotNull case LessThan(_, _) => Not(EqualTo(fromExp, Literal(max, fromType))) case GreaterThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _) => @@ -145,9 +146,9 @@ object UnwrapCast extends Rule[LogicalPlan] { } else if (minCmp < 0) { exp match { case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => - trueIfNotNull(fromExp) + fromExp.trueIfNotNull case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) => - falseIfNotNull(fromExp) + fromExp.falseIfNotNull case EqualNullSafe(_, _) => FalseLiteral case _ => exp @@ -155,9 +156,9 @@ object UnwrapCast extends Rule[LogicalPlan] { } else if (minCmp == 0) { exp match { case LessThan(_, _) => - falseIfNotNull(fromExp) + fromExp.falseIfNotNull case GreaterThanOrEqual(_, _) => - trueIfNotNull(fromExp) + fromExp.trueIfNotNull case GreaterThan(_, _) => Not(EqualTo(fromExp, Literal(min, fromType))) case LessThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _) => @@ -175,6 +176,7 @@ object UnwrapCast extends Rule[LogicalPlan] { case EqualNullSafe(_, _) => EqualNullSafe(fromExp, lit) case LessThan(_, _) => LessThan(fromExp, lit) case LessThanOrEqual(_, _) => LessThanOrEqual(fromExp, lit) + case _ => exp } } 
} @@ -186,29 +188,27 @@ object UnwrapCast extends Rule[LogicalPlan] { */ private def canImplicitlyCast(fromExp: Expression, toType: DataType): Boolean = { fromExp.dataType.isInstanceOf[IntegralType] && toType.isInstanceOf[IntegralType] && - Cast.canUpCast(fromExp.dataType, toType) + Cast.canUpCast(fromExp.dataType, toType) } - /** - * Wraps input expression `e` with `if(isnull(e), null, false)`. The if-clause is represented - * using `and(isnull(e), null)` which is semantically equivalent by applying 3-valued logic. - */ - private[sql] def falseIfNotNull(e: Expression): Expression = - And(IsNull(e), Literal(null, BooleanType)) - - /** - * Wraps input expression `e` with `if(isnull(e), null, true)`. The if-clause is represented - * using `or(isnotnull(e), null)` which is semantically equivalent by applying 3-valued logic. - */ - private[sql] def trueIfNotNull(e: Expression): Expression = - Or(IsNotNull(e), Literal(null, BooleanType)) - - private def getRange(ty: DataType): (Any, Any) = ty match { + private def getRange(dt: DataType): (Any, Any) = dt match { case ByteType => (Byte.MinValue, Byte.MaxValue) case ShortType => (Short.MinValue, Short.MaxValue) case IntegerType => (Int.MinValue, Int.MaxValue) case LongType => (Long.MinValue, Long.MaxValue) } -} - + private[optimizer] implicit class ExpressionWrapper(e: Expression) { + /** + * Wraps input expression `e` with `if(isnull(e), null, false)`. The if-clause is represented + * using `and(isnull(e), null)` which is semantically equivalent by applying 3-valued logic. + */ + def falseIfNotNull: Expression = And(IsNull(e), Literal(null, BooleanType)) + + /** + * Wraps input expression `e` with `if(isnull(e), null, true)`. The if-clause is represented + * using `or(isnotnull(e), null)` which is semantically equivalent by applying 3-valued logic. 
+ */ + def trueIfNotNull: Expression = Or(IsNotNull(e), Literal(null, BooleanType)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala similarity index 68% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala index 707838342c4aa..8606eef737816 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala @@ -20,58 +20,61 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils._ +import org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType} -class UnwrapCastSuite extends PlanTest { +class UnwrapCastInBinaryComparisonSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches: List[Batch] = - Batch("Unwrap casts", FixedPoint(10), ConstantFolding, UnwrapCast) :: Nil + Batch("Unwrap casts in binary comparison", FixedPoint(10), + NullPropagation, ConstantFolding, UnwrapCastInBinaryComparison) :: Nil } val testRelation: LocalRelation = LocalRelation('a.short, 'b.float) test("unwrap casts when literal == max") { val v = Short.MaxValue - 
assertEquivalent('a > v.toInt, UnwrapCast.falseIfNotNull('a)) + assertEquivalent('a > v.toInt, 'a.attr.falseIfNotNull) assertEquivalent('a >= v.toInt, 'a === v) assertEquivalent('a === v.toInt, 'a === v) assertEquivalent('a <=> v.toInt, 'a === v) - assertEquivalent('a <= v.toInt, UnwrapCast.trueIfNotNull('a)) + assertEquivalent('a <= v.toInt, 'a.attr.trueIfNotNull) assertEquivalent('a < v.toInt, 'a =!= v) } test("unwrap casts when literal > max") { - val v: Int = Short.MaxValue + 100 - assertEquivalent('a > v, UnwrapCast.falseIfNotNull('a)) - assertEquivalent('a >= v, UnwrapCast.falseIfNotNull('a)) - assertEquivalent('a === v, UnwrapCast.falseIfNotNull('a)) + val v: Int = positiveInt + assertEquivalent('a > v, 'a.attr.falseIfNotNull) + assertEquivalent('a >= v, 'a.attr.falseIfNotNull) + assertEquivalent('a === v, 'a.attr.falseIfNotNull) assertEquivalent('a <=> v, false) - assertEquivalent('a <= v, UnwrapCast.trueIfNotNull('a)) - assertEquivalent('a < v, UnwrapCast.trueIfNotNull('a)) + assertEquivalent('a <= v, 'a.attr.trueIfNotNull) + assertEquivalent('a < v, 'a.attr.trueIfNotNull) } test("unwrap casts when literal == min") { val v = Short.MinValue assertEquivalent('a > v.toInt, 'a =!= v) - assertEquivalent('a >= v.toInt, UnwrapCast.trueIfNotNull('a)) + assertEquivalent('a >= v.toInt, 'a.attr.trueIfNotNull) assertEquivalent('a === v.toInt, 'a === v) assertEquivalent('a <=> v.toInt, 'a === v) assertEquivalent('a <= v.toInt, 'a === v) - assertEquivalent('a < v.toInt, UnwrapCast.falseIfNotNull('a)) + assertEquivalent('a < v.toInt, 'a.attr.falseIfNotNull) } test("unwrap casts when literal < min") { - val v: Int = Short.MinValue - 100 - assertEquivalent('a > v, UnwrapCast.trueIfNotNull('a)) - assertEquivalent('a >= v, UnwrapCast.trueIfNotNull('a)) - assertEquivalent('a === v, UnwrapCast.falseIfNotNull('a)) + val v: Int = negativeInt + assertEquivalent('a > v, 'a.attr.trueIfNotNull) + assertEquivalent('a >= v, 'a.attr.trueIfNotNull) + assertEquivalent('a === v, 
'a.attr.falseIfNotNull) assertEquivalent('a <=> v, false) - assertEquivalent('a <= v, UnwrapCast.falseIfNotNull('a)) - assertEquivalent('a < v, UnwrapCast.falseIfNotNull('a)) + assertEquivalent('a <= v, 'a.attr.falseIfNotNull) + assertEquivalent('a < v, 'a.attr.falseIfNotNull) } test("unwrap casts when literal is within range (min, max)") { @@ -85,11 +88,11 @@ class UnwrapCastSuite extends PlanTest { test("unwrap casts when cast is on rhs") { val v = Short.MaxValue - assertEquivalent(Literal(v.toInt) < 'a, UnwrapCast.falseIfNotNull('a)) + assertEquivalent(Literal(v.toInt) < 'a, 'a.attr.falseIfNotNull) assertEquivalent(Literal(v.toInt) <= 'a, Literal(v) === 'a) assertEquivalent(Literal(v.toInt) === 'a, Literal(v) === 'a) assertEquivalent(Literal(v.toInt) <=> 'a, Literal(v) === 'a) - assertEquivalent(Literal(v.toInt) >= 'a, UnwrapCast.trueIfNotNull('a)) + assertEquivalent(Literal(v.toInt) >= 'a, 'a.attr.trueIfNotNull) assertEquivalent(Literal(v.toInt) > 'a, 'a =!= v) assertEquivalent(Literal(30) <= 'a, Literal(30.toShort) <= 'a) @@ -110,6 +113,17 @@ class UnwrapCastSuite extends PlanTest { assertEquivalent(Literal(42.0) < 'b, Literal(42.0) < Cast('b, DoubleType)) } + test("unwrap casts when literal is null") { + val intLit = Literal.create(null, IntegerType) + val nullLit = Literal.create(null, BooleanType) + assertEquivalent('a > intLit, nullLit) + assertEquivalent('a >= intLit, nullLit) + assertEquivalent('a === intLit, nullLit) + assertEquivalent('a <=> intLit, IsNull(Cast('a, IntegerType))) + assertEquivalent('a <= intLit, nullLit) + assertEquivalent('a < intLit, nullLit) + } + private def assertEquivalent(e1: Expression, e2: Expression): Unit = { val plan = testRelation.where(e1).analyze val actual = Optimize.execute(plan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index a3cd0c230d8af..5eae588a7b7f8 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -31,12 +31,14 @@ import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.FilePartition import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation, FileScan} -import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetTable +import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan +import org.apache.spark.sql.execution.datasources.v2.parquet.{ParquetScan, ParquetTable} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -881,6 +883,90 @@ class FileBasedDataSourceSuite extends QueryTest } } } + + test("test casts pushdown on orc/parquet for integral types") { + def checkPushedFilters( + format: String, + df: DataFrame, + filters: Array[sources.Filter], + noScan: Boolean = false): Unit = { + val scanExec = df.queryExecution.sparkPlan.find(_.isInstanceOf[BatchScanExec]) + if (noScan) { + assert(scanExec.isEmpty) + return + } + val scan = scanExec.get.asInstanceOf[BatchScanExec].scan + format match { + case "orc" => + assert(scan.isInstanceOf[OrcScan]) + assert(scan.asInstanceOf[OrcScan].pushedFilters === filters) + case "parquet" => + assert(scan.isInstanceOf[ParquetScan]) + 
assert(scan.asInstanceOf[ParquetScan].pushedFilters === filters) + case _ => + fail(s"unknown format $format") + } + } + + Seq("orc", "parquet").foreach { format => + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + withTempPath { dir => + spark.range(100).map(i => (i.toShort, i.toString)).toDF("id", "s") + .write + .format(format) + .save(dir.getCanonicalPath) + val df = spark.read.format(format).load(dir.getCanonicalPath) + + // cases when value == MAX + var v = Short.MaxValue + checkPushedFilters(format, df.where('id > v.toInt), Array(), noScan = true) + checkPushedFilters(format, df.where('id >= v.toInt), Array(sources.IsNotNull("id"), + sources.EqualTo("id", v))) + checkPushedFilters(format, df.where('id === v.toInt), Array(sources.IsNotNull("id"), + sources.EqualTo("id", v))) + checkPushedFilters(format, df.where('id <=> v.toInt), Array(sources.IsNotNull("id"), + sources.EqualTo("id", v))) + checkPushedFilters(format, df.where('id <= v.toInt), Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where('id < v.toInt), Array(sources.IsNotNull("id"), + sources.Not(sources.EqualTo("id", v)))) + + // cases when value > MAX + var v1: Int = positiveInt + checkPushedFilters(format, df.where('id > v1), Array(), noScan = true) + checkPushedFilters(format, df.where('id >= v1), Array(), noScan = true) + checkPushedFilters(format, df.where('id === v1), Array(), noScan = true) + checkPushedFilters(format, df.where('id <=> v1), Array(), noScan = true) + checkPushedFilters(format, df.where('id <= v1), Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where('id < v1), Array(sources.IsNotNull("id"))) + + // cases when value = MIN + v = Short.MinValue + checkPushedFilters(format, df.where(lit(v.toInt) < 'id), Array(sources.IsNotNull("id"), + sources.Not(sources.EqualTo("id", v)))) + checkPushedFilters(format, df.where(lit(v.toInt) <= 'id), Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(lit(v.toInt) === 'id), 
Array(sources.IsNotNull("id"), + sources.EqualTo("id", v))) + checkPushedFilters(format, df.where(lit(v.toInt) >= 'id), Array(sources.IsNotNull("id"), + sources.EqualTo("id", v))) + checkPushedFilters(format, df.where(lit(v.toInt) > 'id), Array(), noScan = true) + + // cases when value < MIN + v1 = negativeInt + checkPushedFilters(format, df.where(lit(v1) < 'id), Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(lit(v1) <= 'id), Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(lit(v1) === 'id), Array(), noScan = true) + checkPushedFilters(format, df.where(lit(v1) >= 'id), Array(), noScan = true) + checkPushedFilters(format, df.where(lit(v1) > 'id), Array(), noScan = true) + + // cases when value is within range (MIN, MAX) + checkPushedFilters(format, df.where('id > 30), Array(sources.IsNotNull("id"), + sources.GreaterThan("id", 30))) + checkPushedFilters(format, df.where(lit(100) >= 'id), Array(sources.IsNotNull("id"), + sources.LessThanOrEqual("id", 100))) + } + } + } + } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 739131f3bcc32..e20606c3338d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3716,137 +3716,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark fail(s"unknown format $format") } } - - Seq("orc", "parquet").foreach { format => - withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { - withTempPath { dir => - spark.range(100).map(i => (i.toShort, i.toString)).toDF("id", "s") - .write - .format(format) - .save(dir.getCanonicalPath) - val df = spark.read.format(format).load(dir.getCanonicalPath) - - // cases when value == MAX - var v = Short.MaxValue - checkPushedFilters( - format, - df.where('id > v.toInt), - Array(), - noScan = true) - checkPushedFilters( 
- format, - df.where('id >= v.toInt), - Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters( - format, - df.where('id === v.toInt), - Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters( - format, - df.where('id <=> v.toInt), - Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters( - format, - df.where('id <= v.toInt), - Array(sources.IsNotNull("id"))) - checkPushedFilters( - format, - df.where('id < v.toInt), - Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) - - // cases when value > MAX - var v1: Int = Short.MaxValue.toInt + 100 - checkPushedFilters( - format, - df.where('id > v1), - Array(), - noScan = true) - checkPushedFilters( - format, - df.where('id >= v1), - Array(), - noScan = true) - checkPushedFilters( - format, - df.where('id === v1), - Array(), - noScan = true) - checkPushedFilters( - format, - df.where('id <=> v1), - Array(), - noScan = true) - checkPushedFilters( - format, - df.where('id <= v1), - Array(sources.IsNotNull("id"))) - checkPushedFilters( - format, - df.where('id < v1), - Array(sources.IsNotNull("id"))) - - // cases when value = MIN - v = Short.MinValue - checkPushedFilters( - format, - df.where(lit(v.toInt) < 'id), - Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) - checkPushedFilters( - format, - df.where(lit(v.toInt) <= 'id), - Array(sources.IsNotNull("id"))) - checkPushedFilters( - format, - df.where(lit(v.toInt) === 'id), - Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters( - format, - df.where(lit(v.toInt) >= 'id), - Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters( - format, - df.where(lit(v.toInt) > 'id), - Array(), - noScan = true) - - // cases when value < MIN - v1 = Short.MinValue.toInt - 100 - checkPushedFilters( - format, - df.where(lit(v1) < 'id), - Array(sources.IsNotNull("id"))) - checkPushedFilters( - format, - df.where(lit(v1) <= 'id), - 
Array(sources.IsNotNull("id"))) - checkPushedFilters( - format, - df.where(lit(v1) === 'id), - Array(), - noScan = true) - checkPushedFilters( - format, - df.where(lit(v1) >= 'id), - Array(), - noScan = true) - checkPushedFilters( - format, - df.where(lit(v1) > 'id), - Array(), - noScan = true) - - // cases when value is within range (MIN, MAX) - checkPushedFilters( - format, - df.where('id > 30), - Array(sources.IsNotNull("id"), sources.GreaterThan("id", 30))) - checkPushedFilters( - format, - df.where(lit(100) >= 'id), - Array(sources.IsNotNull("id"), sources.LessThanOrEqual("id", 100))) - } - } - } } } From 87e6ff8ae6185cd560b1c6604d24218324c44efd Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 31 Aug 2020 18:26:07 -0700 Subject: [PATCH 04/17] Improve doc --- .../UnwrapCastInBinaryComparison.scala | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index 331bb9215a099..8b75ac50bba7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -44,26 +44,26 @@ import org.apache.spark.sql.types._ * cases and try to replace each with simpler constructs. 
* * if `value > max`, the cases are of following: - * - `cast(exp, ty) > value` ==> if(isnull(exp), null, false) - * - `cast(exp, ty) >= value` ==> if(isnull(exp), null, false) - * - `cast(exp, ty) === value` ==> if(isnull(exp), null, false) - * - `cast(exp, ty) <=> value` ==> false - * - `cast(exp, ty) <= value` ==> if(isnull(exp), null, true) - * - `cast(exp, ty) < value` ==> if(isnull(exp), null, true) + * - `cast(fromExp, toType) > value` ==> if(isnull(fromExp), null, false) + * - `cast(fromExp, toType) >= value` ==> if(isnull(fromExp), null, false) + * - `cast(fromExp, toType) === value` ==> if(isnull(fromExp), null, false) + * - `cast(fromExp, toType) <=> value` ==> false + * - `cast(fromExp, toType) <= value` ==> if(isnull(fromExp), null, true) + * - `cast(fromExp, toType) < value` ==> if(isnull(fromExp), null, true) * * if `value == max`, the cases are of following: - * - `cast(exp, ty) > value` ==> if(isnull(exp), null, false) - * - `cast(exp, ty) >= value` ==> exp == max - * - `cast(exp, ty) === value` ==> exp == max - * - `cast(exp, ty) <=> value` ==> exp == max - * - `cast(exp, ty) <= value` ==> if(isnull(exp), null, true) - * - `cast(exp, ty) < value` ==> exp =!= max + * - `cast(fromExp, toType) > value` ==> if(isnull(fromExp), null, false) + * - `cast(fromExp, toType) >= value` ==> fromExp == max + * - `cast(fromExp, toType) === value` ==> fromExp == max + * - `cast(fromExp, toType) <=> value` ==> fromExp == max + * - `cast(fromExp, toType) <= value` ==> if(isnull(fromExp), null, true) + * - `cast(fromExp, toType) < value` ==> fromExp =!= max * * Similarly for the cases when `value == min` and `value < min`. * - * Further, the above `if(isnull(exp), null, false)` is represented using conjunction - * `and(isnull(exp), null)`, to enable further optimization and filter pushdown to data sources. - * Similarly, `if(isnull(exp), null, true)` is represented with `or(isnotnull(exp), null)`. 
+ * Further, the above `if(isnull(fromExp), null, false)` is represented using conjunction + * `and(isnull(fromExp), null)`, to enable further optimization and filter pushdown to data sources. + * Similarly, `if(isnull(fromExp), null, true)` is represented with `or(isnotnull(fromExp), null)`. */ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { From fa82112dfd140fb698983cb617addeedef4f4ce7 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 31 Aug 2020 19:06:49 -0700 Subject: [PATCH 05/17] Shouldn't skip evaluation if exp is non-deterministic --- .../catalyst/optimizer/UnwrapCastInBinaryComparison.scala | 8 +++++--- .../optimizer/UnwrapCastInBinaryComparisonSuite.scala | 8 ++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index 8b75ac50bba7f..d5f46a4e90292 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -90,7 +90,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { swap(unwrapCast(swap(exp))) - case BinaryComparison(Cast(fromExp, _, _), Literal(value, toType: IntegralType)) + case BinaryComparison(Cast(fromExp, _: IntegralType, _), Literal(value, toType: IntegralType)) if canImplicitlyCast(fromExp, toType) => // In case both sides have integral type, optimize the comparison by removing casts or // moving cast to the literal side. 
@@ -128,7 +128,8 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { case LessThan(_, _) | LessThanOrEqual(_, _) => fromExp.trueIfNotNull case EqualNullSafe(_, _) => - FalseLiteral + // make sure the expression is evaluated if it is non-deterministic + if (exp.deterministic) FalseLiteral else exp case _ => exp // impossible but safe guard, same below } } else if (maxCmp == 0) { @@ -150,7 +151,8 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) => fromExp.falseIfNotNull case EqualNullSafe(_, _) => - FalseLiteral + // make sure the expression is evaluated if it is non-deterministic + if (exp.deterministic) FalseLiteral else exp case _ => exp } } else if (minCmp == 0) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala index 8606eef737816..e4ec0124cc418 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils._ +import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ @@ -113,6 +114,13 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest { assertEquivalent(Literal(42.0) < 'b, Literal(42.0) < Cast('b, DoubleType)) } + test("unwrap cast should skip when expression is 
non-deterministic") { + assertEquivalent(Cast(First('a, ignoreNulls = true), IntegerType) <=> positiveInt, + Cast(First('a, ignoreNulls = true), IntegerType) <=> positiveInt) + assertEquivalent(Cast(First('a, ignoreNulls = true), IntegerType) <=> negativeInt, + Cast(First('a, ignoreNulls = true), IntegerType) <=> negativeInt) + } + test("unwrap casts when literal is null") { val intLit = Literal.create(null, IntegerType) val nullLit = Literal.create(null, BooleanType) From 92943b488c931b675820d4dd674157c148395d4f Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 1 Sep 2020 10:44:02 -0700 Subject: [PATCH 06/17] Minor fixes on style --- .../UnwrapCastInBinaryComparison.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index d5f46a4e90292..e7903e80829c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -74,10 +74,10 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { } private def unwrapCast(exp: Expression): Expression = exp match { + // Not a canonical form. In this case we first canonicalize the expression by swapping the + // literal and cast side, then process the result and swap the literal and cast again to + // restore the original order. case BinaryComparison(Literal(_, _), Cast(_, _, _)) => - // Not a canonical form. In this case we first canonicalize the expression by swapping the - // literal and cast side, then process the result and swap the literal and cast again to - // restore the original order. 
def swap(e: Expression): Expression = e match { case GreaterThan(left, right) => LessThan(right, left) case GreaterThanOrEqual(left, right) => LessThanOrEqual(right, left) @@ -87,14 +87,14 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { case LessThan(left, right) => GreaterThan(right, left) case _ => e } - swap(unwrapCast(swap(exp))) - case BinaryComparison(Cast(fromExp, _: IntegralType, _), Literal(value, toType: IntegralType)) + // In case both sides have integral type, optimize the comparison by removing casts or + // moving cast to the literal side. + case be @ BinaryComparison( + Cast(fromExp, _: IntegralType, _), Literal(value, toType: IntegralType)) if canImplicitlyCast(fromExp, toType) => - // In case both sides have integral type, optimize the comparison by removing casts or - // moving cast to the literal side. - simplifyIntegral(exp, fromExp, toType, value) + simplifyIntegral(be, fromExp, toType, value) case _ => exp } @@ -106,7 +106,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { * true, this replaces the input binary comparison `exp` with simpler expressions. 
*/ private def simplifyIntegral( - exp: Expression, + exp: BinaryComparison, fromExp: Expression, toType: IntegralType, value: Any): Expression = { @@ -124,7 +124,6 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { exp match { case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) => fromExp.falseIfNotNull - // falseIfNotNull(fromExp) case LessThan(_, _) | LessThanOrEqual(_, _) => fromExp.trueIfNotNull case EqualNullSafe(_, _) => @@ -190,7 +189,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { */ private def canImplicitlyCast(fromExp: Expression, toType: DataType): Boolean = { fromExp.dataType.isInstanceOf[IntegralType] && toType.isInstanceOf[IntegralType] && - Cast.canUpCast(fromExp.dataType, toType) + Cast.canUpCast(fromExp.dataType, toType) } private def getRange(dt: DataType): (Any, Any) = dt match { From d868e0f9b1a998037bdf21b1d8fc4ce6a2e27ee3 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 1 Sep 2020 10:53:59 -0700 Subject: [PATCH 07/17] Validate input expression before swapping --- .../sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index e7903e80829c2..cf85450426bd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -77,7 +77,8 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { // Not a canonical form. In this case we first canonicalize the expression by swapping the // literal and cast side, then process the result and swap the literal and cast again to // restore the original order. 
- case BinaryComparison(Literal(_, _), Cast(_, _, _)) => + case BinaryComparison(Literal(_, toType: IntegralType), Cast(fromExp, _: IntegralType, _)) + if canImplicitlyCast(fromExp, toType) => def swap(e: Expression): Expression = e match { case GreaterThan(left, right) => LessThan(right, left) case GreaterThanOrEqual(left, right) => LessThanOrEqual(right, left) From a2cabb5f04cabe8cdf7656735ae332bdb8a64cf9 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 1 Sep 2020 13:30:54 -0700 Subject: [PATCH 08/17] Nit --- .../optimizer/UnwrapCastInBinaryComparison.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index cf85450426bd0..19457c9a9b4bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -127,9 +127,9 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { fromExp.falseIfNotNull case LessThan(_, _) | LessThanOrEqual(_, _) => fromExp.trueIfNotNull - case EqualNullSafe(_, _) => - // make sure the expression is evaluated if it is non-deterministic - if (exp.deterministic) FalseLiteral else exp + // make sure the expression is evaluated if it is non-deterministic + case EqualNullSafe(_, _) if exp.deterministic => + FalseLiteral case _ => exp // impossible but safe guard, same below } } else if (maxCmp == 0) { @@ -150,9 +150,9 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { fromExp.trueIfNotNull case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) => fromExp.falseIfNotNull - case EqualNullSafe(_, _) => - // make sure the expression is evaluated if it is non-deterministic - if (exp.deterministic) FalseLiteral else exp + 
// make sure the expression is evaluated if it is non-deterministic + case EqualNullSafe(_, _) if exp.deterministic => + FalseLiteral case _ => exp } } else if (minCmp == 0) { From b59b905243ec5189713e77e711074fc78d98898b Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Wed, 9 Sep 2020 01:20:59 -0700 Subject: [PATCH 09/17] Remove unused test --- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e20606c3338d2..b86df4db816b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3691,32 +3691,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(sql("SELECT id FROM t WHERE (SELECT true)"), Row(0L)) } } - - test("test casts pushdown on orc/parquet for integral types") { - def checkPushedFilters( - format: String, - df: DataFrame, - filters: Array[sources.Filter], - noScan: Boolean = false): Unit = { - val scanExec = df.queryExecution.sparkPlan - .find(_.isInstanceOf[BatchScanExec]) - if (noScan) { - assert(scanExec.isEmpty) - return - } - val scan = scanExec.get.asInstanceOf[BatchScanExec].scan - format match { - case "orc" => - assert(scan.isInstanceOf[OrcScan]) - assert(scan.asInstanceOf[OrcScan].pushedFilters === filters) - case "parquet" => - assert(scan.isInstanceOf[ParquetScan]) - assert(scan.asInstanceOf[ParquetScan].pushedFilters === filters) - case _ => - fail(s"unknown format $format") - } - } - } } case class Foo(bar: Option[String]) From 3a00dce001d9e2d10b917041573707c8f19cb604 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Wed, 9 Sep 2020 00:43:23 -0700 Subject: [PATCH 10/17] Nits --- .../sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index 19457c9a9b4bf..69c6be12a15e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.types._ * - `cast(fromExp, toType) > value` ==> if(isnull(fromExp), null, false) * - `cast(fromExp, toType) >= value` ==> if(isnull(fromExp), null, false) * - `cast(fromExp, toType) === value` ==> if(isnull(fromExp), null, false) - * - `cast(fromExp, toType) <=> value` ==> false + * - `cast(fromExp, toType) <=> value` ==> false (only if `fromExp` is deterministic) * - `cast(fromExp, toType) <= value` ==> if(isnull(fromExp), null, true) * - `cast(fromExp, toType) < value` ==> if(isnull(fromExp), null, true) * @@ -93,7 +93,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { // In case both sides have integral type, optimize the comparison by removing casts or // moving cast to the literal side. 
case be @ BinaryComparison( - Cast(fromExp, _: IntegralType, _), Literal(value, toType: IntegralType)) + Cast(fromExp, toType: IntegralType, _), Literal(value, _: IntegralType)) if canImplicitlyCast(fromExp, toType) => simplifyIntegral(be, fromExp, toType, value) From 265c1698c29730b0dee2865773557d7a3c4c1144 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Wed, 9 Sep 2020 00:50:56 -0700 Subject: [PATCH 11/17] Switch to pattern matching --- .../UnwrapCastInBinaryComparison.scala | 106 ++++++++---------- 1 file changed, 47 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index 69c6be12a15e8..a911fdd5c67b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -120,66 +120,54 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { val ordering = toType.ordering.asInstanceOf[Ordering[Any]] val minCmp = ordering.compare(value, minInToType) val maxCmp = ordering.compare(value, maxInToType) + val lit = Cast(Literal(value), fromType) - if (maxCmp > 0) { - exp match { - case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) => - fromExp.falseIfNotNull - case LessThan(_, _) | LessThanOrEqual(_, _) => - fromExp.trueIfNotNull - // make sure the expression is evaluated if it is non-deterministic - case EqualNullSafe(_, _) if exp.deterministic => - FalseLiteral - case _ => exp // impossible but safe guard, same below - } - } else if (maxCmp == 0) { - exp match { - case GreaterThan(_, _) => - fromExp.falseIfNotNull - case LessThanOrEqual(_, _) => - fromExp.trueIfNotNull - case LessThan(_, _) => - Not(EqualTo(fromExp, Literal(max, fromType))) - case GreaterThanOrEqual(_, _) | EqualTo(_, _) | 
EqualNullSafe(_, _) => - EqualTo(fromExp, Literal(max, fromType)) - case _ => exp - } - } else if (minCmp < 0) { - exp match { - case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => - fromExp.trueIfNotNull - case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) => - fromExp.falseIfNotNull - // make sure the expression is evaluated if it is non-deterministic - case EqualNullSafe(_, _) if exp.deterministic => - FalseLiteral - case _ => exp - } - } else if (minCmp == 0) { - exp match { - case LessThan(_, _) => - fromExp.falseIfNotNull - case GreaterThanOrEqual(_, _) => - fromExp.trueIfNotNull - case GreaterThan(_, _) => - Not(EqualTo(fromExp, Literal(min, fromType))) - case LessThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _) => - EqualTo(fromExp, Literal(min, fromType)) - case _ => exp - } - } else { - // This means `value` is within range `(min, max)`. Optimize this by moving the cast to the - // literal side. - val lit = Cast(Literal(value), fromType) - exp match { - case GreaterThan(_, _) => GreaterThan(fromExp, lit) - case GreaterThanOrEqual(_, _) => GreaterThanOrEqual(fromExp, lit) - case EqualTo(_, _) => EqualTo(fromExp, lit) - case EqualNullSafe(_, _) => EqualNullSafe(fromExp, lit) - case LessThan(_, _) => LessThan(fromExp, lit) - case LessThanOrEqual(_, _) => LessThanOrEqual(fromExp, lit) - case _ => exp - } + (minCmp.signum, maxCmp.signum, exp) match { + case (_, 1, EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _)) => + fromExp.falseIfNotNull + case (_, 1, LessThan(_, _) | LessThanOrEqual(_, _)) => + fromExp.trueIfNotNull + // make sure the expression is evaluated if it is non-deterministic + case (_, 1, EqualNullSafe(_, _)) if exp.deterministic => + FalseLiteral + case (_, 1, _) => exp // impossible but safe guard, same below + + case (_, 0, GreaterThan(_, _)) => + fromExp.falseIfNotNull + case (_, 0, LessThanOrEqual(_, _)) => + fromExp.trueIfNotNull + case (_, 0, LessThan(_, _)) => + Not(EqualTo(fromExp, Literal(max, fromType))) + 
case (_, 0, GreaterThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _)) => + EqualTo(fromExp, Literal(max, fromType)) + case (_, 0, _) => exp + + case (-1, _, GreaterThan(_, _) | GreaterThanOrEqual(_, _)) => + fromExp.trueIfNotNull + case (-1, _, LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _)) => + fromExp.falseIfNotNull + // make sure the expression is evaluated if it is non-deterministic + case (-1, _, EqualNullSafe(_, _)) if exp.deterministic => + FalseLiteral + case (-1, _, _) => exp + + case (0, _, LessThan(_, _)) => + fromExp.falseIfNotNull + case (0, _, GreaterThanOrEqual(_, _)) => + fromExp.trueIfNotNull + case (0, _, GreaterThan(_, _)) => + Not(EqualTo(fromExp, Literal(min, fromType))) + case (0, _, LessThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _)) => + EqualTo(fromExp, Literal(min, fromType)) + case (0, _, _) => exp + + case (_, _, GreaterThan(_, _)) => GreaterThan(fromExp, lit) + case (_, _, GreaterThanOrEqual(_, _)) => GreaterThanOrEqual(fromExp, lit) + case (_, _, EqualTo(_, _)) => EqualTo(fromExp, lit) + case (_, _, EqualNullSafe(_, _)) => EqualNullSafe(fromExp, lit) + case (_, _, LessThan(_, _)) => LessThan(fromExp, lit) + case (_, _, LessThanOrEqual(_, _)) => LessThanOrEqual(fromExp, lit) + case (_, _, _) => exp } } From 71ff2c2e7b087b42cf8622eb456d8b0db7ead93e Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 10 Sep 2020 10:59:14 -0700 Subject: [PATCH 12/17] Address comments - Add comments on type coercion constraint - Add comment for non-deterministic case in `EqualNullSafe` - Rename `simplifyIntegral` to `simplifyIntegralComparison` - Refactor test - Fix nit in pattern matching --- .../UnwrapCastInBinaryComparison.scala | 29 ++++++++++++------- .../UnwrapCastInBinaryComparisonSuite.scala | 8 ++--- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index a911fdd5c67b2..bb721380a3089 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -34,11 +34,18 @@ import org.apache.spark.sql.types._ * constructs, or moving the cast from the expression side to the literal side, which enables them * to be optimized away later and pushed down to data sources. * - * Currently this only handles cases where `fromType` (of `fromExp`) and `toType` are of integral - * types (i.e., byte, short, int and long). The rule checks to see if the literal `value` is - * within range `(min, max)`, where `min` and `max` are the minimum and maximum value of - * `fromType`, respectively. If this is true then it means we can safely cast `value` to `fromType` - * and thus able to move the cast to the literal side. + * Currently this only handles cases where: + * 1). `fromType` (of `fromExp`) and `toType` are of integral types (i.e., byte, short, int and + * long) + * 2). `fromType` can be safely coerced to `toType` without precision loss (e.g., short to int, + * int to long, but not long to int) + * + * If the above conditions are satisfied, the rule checks to see if the literal `value` is within + * range `(min, max)`, where `min` and `max` are the minimum and maximum value of `fromType`, + * respectively. If this is true then it means we can safely cast `value` to `fromType` and thus + * able to move the cast to the literal side. That is: + * + * `cast(fromExp, toType) op value` ==> `fromExp op cast(value, fromType)` * * If the `value` is not within range `(min, max)`, the rule breaks the scenario into different * cases and try to replace each with simpler constructs. 
@@ -47,7 +54,9 @@ import org.apache.spark.sql.types._ * - `cast(fromExp, toType) > value` ==> if(isnull(fromExp), null, false) * - `cast(fromExp, toType) >= value` ==> if(isnull(fromExp), null, false) * - `cast(fromExp, toType) === value` ==> if(isnull(fromExp), null, false) - * - `cast(fromExp, toType) <=> value` ==> false (only if `fromExp` is deterministic) + * - `cast(fromExp, toType) <=> value` ==> false (if `fromExp` is deterministic) + * - `cast(fromExp, toType) <=> value` ==> fromExp <=> cast(value, fromExp) (if `fromExp` is + * non-deterministic) * - `cast(fromExp, toType) <= value` ==> if(isnull(fromExp), null, true) * - `cast(fromExp, toType) < value` ==> if(isnull(fromExp), null, true) * @@ -95,7 +104,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { case be @ BinaryComparison( Cast(fromExp, toType: IntegralType, _), Literal(value, _: IntegralType)) if canImplicitlyCast(fromExp, toType) => - simplifyIntegral(be, fromExp, toType, value) + simplifyIntegralComparison(be, fromExp, toType, value) case _ => exp } @@ -106,7 +115,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { * optimizes the expression by moving the cast to the literal side. Otherwise if result is not * true, this replaces the input binary comparison `exp` with simpler expressions. 
*/ - private def simplifyIntegral( + private def simplifyIntegralComparison( exp: BinaryComparison, fromExp: Expression, toType: IntegralType, @@ -130,7 +139,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { // make sure the expression is evaluated if it is non-deterministic case (_, 1, EqualNullSafe(_, _)) if exp.deterministic => FalseLiteral - case (_, 1, _) => exp // impossible but safe guard, same below + case (_, 1, _) => exp case (_, 0, GreaterThan(_, _)) => fromExp.falseIfNotNull @@ -167,7 +176,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { case (_, _, EqualNullSafe(_, _)) => EqualNullSafe(fromExp, lit) case (_, _, LessThan(_, _)) => LessThan(fromExp, lit) case (_, _, LessThanOrEqual(_, _)) => LessThanOrEqual(fromExp, lit) - case (_, _, _) => exp + case _ => exp } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala index e4ec0124cc418..748459d687258 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala @@ -115,10 +115,10 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest { } test("unwrap cast should skip when expression is non-deterministic") { - assertEquivalent(Cast(First('a, ignoreNulls = true), IntegerType) <=> positiveInt, - Cast(First('a, ignoreNulls = true), IntegerType) <=> positiveInt) - assertEquivalent(Cast(First('a, ignoreNulls = true), IntegerType) <=> negativeInt, - Cast(First('a, ignoreNulls = true), IntegerType) <=> negativeInt) + Seq(positiveInt, negativeInt).foreach (v => { + val e = Cast(First('a, ignoreNulls = true), IntegerType) <=> v + assertEquivalent(e, e) + }) } test("unwrap casts when literal is null") { From 
fc91795cfc0268f7ebfeaa4697c760708980199b Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 10 Sep 2020 14:41:33 -0700 Subject: [PATCH 13/17] Fix a bug when optimizing EqualNullSafe --- .../UnwrapCastInBinaryComparison.scala | 8 +- .../UnwrapCastInBinaryComparisonSuite.scala | 140 ++++++++++-------- 2 files changed, 85 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index bb721380a3089..b845fe7b6de7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -147,8 +147,10 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { fromExp.trueIfNotNull case (_, 0, LessThan(_, _)) => Not(EqualTo(fromExp, Literal(max, fromType))) - case (_, 0, GreaterThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _)) => + case (_, 0, GreaterThanOrEqual(_, _) | EqualTo(_, _)) => EqualTo(fromExp, Literal(max, fromType)) + case (_, 0, EqualNullSafe(_, _)) => + EqualNullSafe(fromExp, Literal(max, fromType)) case (_, 0, _) => exp case (-1, _, GreaterThan(_, _) | GreaterThanOrEqual(_, _)) => @@ -166,8 +168,10 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { fromExp.trueIfNotNull case (0, _, GreaterThan(_, _)) => Not(EqualTo(fromExp, Literal(min, fromType))) - case (0, _, LessThanOrEqual(_, _) | EqualTo(_, _) | EqualNullSafe(_, _)) => + case (0, _, LessThanOrEqual(_, _) | EqualTo(_, _)) => EqualTo(fromExp, Literal(min, fromType)) + case (0, _, EqualNullSafe(_, _)) => + EqualNullSafe(fromExp, Literal(min, fromType)) case (0, _, _) => exp case (_, _, GreaterThan(_, _)) => GreaterThan(fromExp, lit) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala index 748459d687258..4ee63be442e27 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala @@ -26,9 +26,9 @@ import org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType} +import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, IntegerType} -class UnwrapCastInBinaryComparisonSuite extends PlanTest { +class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelper { object Optimize extends RuleExecutor[LogicalPlan] { val batches: List[Batch] = @@ -37,105 +37,123 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest { } val testRelation: LocalRelation = LocalRelation('a.short, 'b.float) + val f = 'a.short.canBeNull.at(0) test("unwrap casts when literal == max") { val v = Short.MaxValue - assertEquivalent('a > v.toInt, 'a.attr.falseIfNotNull) - assertEquivalent('a >= v.toInt, 'a === v) - assertEquivalent('a === v.toInt, 'a === v) - assertEquivalent('a <=> v.toInt, 'a === v) - assertEquivalent('a <= v.toInt, 'a.attr.trueIfNotNull) - assertEquivalent('a < v.toInt, 'a =!= v) + assertEquivalent(castInt(f) > v.toInt, f.falseIfNotNull) + assertEquivalent(castInt(f) >= v.toInt, f === v) + assertEquivalent(castInt(f) === v.toInt, f === v) + assertEquivalent(castInt(f) <=> v.toInt, f <=> v) + assertEquivalent(castInt(f) <= v.toInt, f.trueIfNotNull) + assertEquivalent(castInt(f) < v.toInt, f =!= v) } 
test("unwrap casts when literal > max") { val v: Int = positiveInt - assertEquivalent('a > v, 'a.attr.falseIfNotNull) - assertEquivalent('a >= v, 'a.attr.falseIfNotNull) - assertEquivalent('a === v, 'a.attr.falseIfNotNull) - assertEquivalent('a <=> v, false) - assertEquivalent('a <= v, 'a.attr.trueIfNotNull) - assertEquivalent('a < v, 'a.attr.trueIfNotNull) + assertEquivalent(castInt(f) > v, f.falseIfNotNull) + assertEquivalent(castInt(f) >= v, f.falseIfNotNull) + assertEquivalent(castInt(f) === v, f.falseIfNotNull) + assertEquivalent(castInt(f) <=> v, false) + assertEquivalent(castInt(f) <= v, f.trueIfNotNull) + assertEquivalent(castInt(f) < v, f.trueIfNotNull) } test("unwrap casts when literal == min") { val v = Short.MinValue - assertEquivalent('a > v.toInt, 'a =!= v) - assertEquivalent('a >= v.toInt, 'a.attr.trueIfNotNull) - assertEquivalent('a === v.toInt, 'a === v) - assertEquivalent('a <=> v.toInt, 'a === v) - assertEquivalent('a <= v.toInt, 'a === v) - assertEquivalent('a < v.toInt, 'a.attr.falseIfNotNull) + assertEquivalent(castInt(f) > v.toInt, f =!= v) + assertEquivalent(castInt(f) >= v.toInt, f.trueIfNotNull) + assertEquivalent(castInt(f) === v.toInt, f === v) + assertEquivalent(castInt(f) <=> v.toInt, f <=> v) + assertEquivalent(castInt(f) <= v.toInt, f === v) + assertEquivalent(castInt(f) < v.toInt, f.falseIfNotNull) } test("unwrap casts when literal < min") { val v: Int = negativeInt - assertEquivalent('a > v, 'a.attr.trueIfNotNull) - assertEquivalent('a >= v, 'a.attr.trueIfNotNull) - assertEquivalent('a === v, 'a.attr.falseIfNotNull) - assertEquivalent('a <=> v, false) - assertEquivalent('a <= v, 'a.attr.falseIfNotNull) - assertEquivalent('a < v, 'a.attr.falseIfNotNull) + assertEquivalent(castInt(f) > v, f.trueIfNotNull) + assertEquivalent(castInt(f) >= v, f.trueIfNotNull) + assertEquivalent(castInt(f) === v, f.falseIfNotNull) + assertEquivalent(castInt(f) <=> v, false) + assertEquivalent(castInt(f) <= v, f.falseIfNotNull) + 
assertEquivalent(castInt(f) < v, f.falseIfNotNull) } test("unwrap casts when literal is within range (min, max)") { - assertEquivalent('a > 300, 'a > 300.toShort) - assertEquivalent('a >= 500, 'a >= 500.toShort) - assertEquivalent('a === 32766, 'a === 32766.toShort) - assertEquivalent('a <=> 32766, 'a <=> 32766.toShort) - assertEquivalent('a <= -6000, 'a <= -6000.toShort) - assertEquivalent('a < -32767, 'a < -32767.toShort) + assertEquivalent(castInt(f) > 300, f > 300.toShort) + assertEquivalent(castInt(f) >= 500, f >= 500.toShort) + assertEquivalent(castInt(f) === 32766, f === 32766.toShort) + assertEquivalent(castInt(f) <=> 32766, f <=> 32766.toShort) + assertEquivalent(castInt(f) <= -6000, f <= -6000.toShort) + assertEquivalent(castInt(f) < -32767, f < -32767.toShort) } test("unwrap casts when cast is on rhs") { val v = Short.MaxValue - assertEquivalent(Literal(v.toInt) < 'a, 'a.attr.falseIfNotNull) - assertEquivalent(Literal(v.toInt) <= 'a, Literal(v) === 'a) - assertEquivalent(Literal(v.toInt) === 'a, Literal(v) === 'a) - assertEquivalent(Literal(v.toInt) <=> 'a, Literal(v) === 'a) - assertEquivalent(Literal(v.toInt) >= 'a, 'a.attr.trueIfNotNull) - assertEquivalent(Literal(v.toInt) > 'a, 'a =!= v) - - assertEquivalent(Literal(30) <= 'a, Literal(30.toShort) <= 'a) + assertEquivalent(Literal(v.toInt) < castInt(f), f.falseIfNotNull) + assertEquivalent(Literal(v.toInt) <= castInt(f), Literal(v) === f) + assertEquivalent(Literal(v.toInt) === castInt(f), Literal(v) === f) + assertEquivalent(Literal(v.toInt) <=> castInt(f), Literal(v) <=> f) + assertEquivalent(Literal(v.toInt) >= castInt(f), f.trueIfNotNull) + assertEquivalent(Literal(v.toInt) > castInt(f), f =!= v) + + assertEquivalent(Literal(30) <= castInt(f), Literal(30.toShort) <= f) } test("unwrap cast should have no effect when input is not integral type") { - assertEquivalent('b > 42.0, Cast('b, DoubleType) > 42.0) - assertEquivalent('b >= 42.0, Cast('b, DoubleType) >= 42.0) - assertEquivalent('b === 42.0, 
Cast('b, DoubleType) === 42.0) - assertEquivalent('b <=> 42.0, Cast('b, DoubleType) <=> 42.0) - assertEquivalent('b <= 42.0, Cast('b, DoubleType) <= 42.0) - assertEquivalent('b < 42.0, Cast('b, DoubleType) < 42.0) - assertEquivalent(Literal(42.0) > 'b, Literal(42.0) > Cast('b, DoubleType)) - assertEquivalent(Literal(42.0) >= 'b, Literal(42.0) >= Cast('b, DoubleType)) - assertEquivalent(Literal(42.0) === 'b, Literal(42.0) === Cast('b, DoubleType)) - assertEquivalent(Literal(42.0) <=> 'b, Literal(42.0) <=> Cast('b, DoubleType)) - assertEquivalent(Literal(42.0) <= 'b, Literal(42.0) <= Cast('b, DoubleType)) - assertEquivalent(Literal(42.0) < 'b, Literal(42.0) < Cast('b, DoubleType)) + Seq( + Cast('b, DoubleType) > 42.0, + Cast('b, DoubleType) >= 42.0, + Cast('b, DoubleType) === 42.0, + Cast('b, DoubleType) <=> 42.0, + Cast('b, DoubleType) <= 42.0, + Cast('b, DoubleType) < 42.0, + Literal(42.0) > Cast('b, DoubleType), + Literal(42.0) >= Cast('b, DoubleType), + Literal(42.0) === Cast('b, DoubleType), + Literal(42.0) <=> Cast('b, DoubleType), + Literal(42.0) <= Cast('b, DoubleType), + Literal(42.0) < Cast('b, DoubleType), + ).foreach(e => + assertEquivalent(e, e, evaluate = false) + ) } test("unwrap cast should skip when expression is non-deterministic") { Seq(positiveInt, negativeInt).foreach (v => { - val e = Cast(First('a, ignoreNulls = true), IntegerType) <=> v - assertEquivalent(e, e) + val e = Cast(First(f, ignoreNulls = true), IntegerType) <=> v + assertEquivalent(e, e, evaluate = false) }) } test("unwrap casts when literal is null") { val intLit = Literal.create(null, IntegerType) val nullLit = Literal.create(null, BooleanType) - assertEquivalent('a > intLit, nullLit) - assertEquivalent('a >= intLit, nullLit) - assertEquivalent('a === intLit, nullLit) - assertEquivalent('a <=> intLit, IsNull(Cast('a, IntegerType))) - assertEquivalent('a <= intLit, nullLit) - assertEquivalent('a < intLit, nullLit) + assertEquivalent(castInt(f) > intLit, nullLit) + 
assertEquivalent(castInt(f) >= intLit, nullLit) + assertEquivalent(castInt(f) === intLit, nullLit) + assertEquivalent(castInt(f) <=> intLit, IsNull(castInt(f))) + assertEquivalent(castInt(f) <= intLit, nullLit) + assertEquivalent(castInt(f) < intLit, nullLit) } - private def assertEquivalent(e1: Expression, e2: Expression): Unit = { + test("unwrap cast should skip if cannot coerce type") { + assertEquivalent(Cast(f, ByteType) > 100.toByte, Cast(f, ByteType) > 100.toByte) + } + + private def castInt(f: BoundReference): Expression = Cast(f, IntegerType) + + private def assertEquivalent(e1: Expression, e2: Expression, evaluate: Boolean = true): Unit = { val plan = testRelation.where(e1).analyze val actual = Optimize.execute(plan) val expected = testRelation.where(e2).analyze comparePlans(actual, expected) + + if (evaluate) { + Seq(100.toShort, -300.toShort, null).foreach(v => { + val row = create_row(v) + checkEvaluation(e1, e2.eval(row), row) + }) + } } } From 1f87c37a82bded5c4c9d3e38b7210b9114bba877 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 10 Sep 2020 14:48:42 -0700 Subject: [PATCH 14/17] Remove implicit --- .../UnwrapCastInBinaryComparison.scala | 63 ++++++++++--------- .../UnwrapCastInBinaryComparisonSuite.scala | 62 +++++++++--------- 2 files changed, 67 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index b845fe7b6de7b..d3b9e4024f57e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -55,7 +55,7 @@ import org.apache.spark.sql.types._ * - `cast(fromExp, toType) >= value` ==> if(isnull(fromExp), null, false) * - `cast(fromExp, toType) === value` ==> if(isnull(fromExp), null, 
false) * - `cast(fromExp, toType) <=> value` ==> false (if `fromExp` is deterministic) - * - `cast(fromExp, toType) <=> value` ==> fromExp <=> cast(value, fromExp) (if `fromExp` is + * - `cast(fromExp, toType) <=> value` ==> cast(fromExp, toType) <=> value (if `fromExp` is * non-deterministic) * - `cast(fromExp, toType) <= value` ==> if(isnull(fromExp), null, true) * - `cast(fromExp, toType) < value` ==> if(isnull(fromExp), null, true) @@ -64,7 +64,7 @@ import org.apache.spark.sql.types._ * - `cast(fromExp, toType) > value` ==> if(isnull(fromExp), null, false) * - `cast(fromExp, toType) >= value` ==> fromExp == max * - `cast(fromExp, toType) === value` ==> fromExp == max - * - `cast(fromExp, toType) <=> value` ==> fromExp == max + * - `cast(fromExp, toType) <=> value` ==> fromExp <=> max * - `cast(fromExp, toType) <= value` ==> if(isnull(fromExp), null, true) * - `cast(fromExp, toType) < value` ==> fromExp =!= max * @@ -86,8 +86,8 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { // Not a canonical form. In this case we first canonicalize the expression by swapping the // literal and cast side, then process the result and swap the literal and cast again to // restore the original order. - case BinaryComparison(Literal(_, toType: IntegralType), Cast(fromExp, _: IntegralType, _)) - if canImplicitlyCast(fromExp, toType) => + case BinaryComparison(Literal(_, literalType), Cast(fromExp, toType, _)) + if canImplicitlyCast(fromExp, toType, literalType) => def swap(e: Expression): Expression = e match { case GreaterThan(left, right) => LessThan(right, left) case GreaterThanOrEqual(left, right) => LessThanOrEqual(right, left) @@ -97,13 +97,14 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { case LessThan(left, right) => GreaterThan(right, left) case _ => e } + swap(unwrapCast(swap(exp))) // In case both sides have integral type, optimize the comparison by removing casts or // moving cast to the literal side. 
case be @ BinaryComparison( - Cast(fromExp, toType: IntegralType, _), Literal(value, _: IntegralType)) - if canImplicitlyCast(fromExp, toType) => + Cast(fromExp, toType: IntegralType, _), Literal(value, literalType)) + if canImplicitlyCast(fromExp, toType, literalType) => simplifyIntegralComparison(be, fromExp, toType, value) case _ => exp @@ -133,18 +134,18 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { (minCmp.signum, maxCmp.signum, exp) match { case (_, 1, EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _)) => - fromExp.falseIfNotNull + falseIfNotNull(fromExp) case (_, 1, LessThan(_, _) | LessThanOrEqual(_, _)) => - fromExp.trueIfNotNull + trueIfNotNull(fromExp) // make sure the expression is evaluated if it is non-deterministic case (_, 1, EqualNullSafe(_, _)) if exp.deterministic => FalseLiteral case (_, 1, _) => exp case (_, 0, GreaterThan(_, _)) => - fromExp.falseIfNotNull + falseIfNotNull(fromExp) case (_, 0, LessThanOrEqual(_, _)) => - fromExp.trueIfNotNull + trueIfNotNull(fromExp) case (_, 0, LessThan(_, _)) => Not(EqualTo(fromExp, Literal(max, fromType))) case (_, 0, GreaterThanOrEqual(_, _) | EqualTo(_, _)) => @@ -154,18 +155,18 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { case (_, 0, _) => exp case (-1, _, GreaterThan(_, _) | GreaterThanOrEqual(_, _)) => - fromExp.trueIfNotNull + trueIfNotNull(fromExp) case (-1, _, LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _)) => - fromExp.falseIfNotNull + falseIfNotNull(fromExp) // make sure the expression is evaluated if it is non-deterministic case (-1, _, EqualNullSafe(_, _)) if exp.deterministic => FalseLiteral case (-1, _, _) => exp case (0, _, LessThan(_, _)) => - fromExp.falseIfNotNull + falseIfNotNull(fromExp) case (0, _, GreaterThanOrEqual(_, _)) => - fromExp.trueIfNotNull + trueIfNotNull(fromExp) case (0, _, GreaterThan(_, _)) => Not(EqualTo(fromExp, Literal(min, fromType))) case (0, _, LessThanOrEqual(_, _) | EqualTo(_, _)) => @@ -189,8 +190,11 
@@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { * i.e., the conversion is injective. Note this only handles the case when both sides are of * integral type. */ - private def canImplicitlyCast(fromExp: Expression, toType: DataType): Boolean = { - fromExp.dataType.isInstanceOf[IntegralType] && toType.isInstanceOf[IntegralType] && + private def canImplicitlyCast(fromExp: Expression, toType: DataType, + literalType: DataType): Boolean = { + toType.sameType(literalType) && + fromExp.dataType.isInstanceOf[IntegralType] && + toType.isInstanceOf[IntegralType] && Cast.canUpCast(fromExp.dataType, toType) } @@ -199,19 +203,22 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { case ShortType => (Short.MinValue, Short.MaxValue) case IntegerType => (Int.MinValue, Int.MaxValue) case LongType => (Long.MinValue, Long.MaxValue) + case other => throw new IllegalArgumentException(s"Unsupported type: ${other.catalogString}") } - private[optimizer] implicit class ExpressionWrapper(e: Expression) { - /** - * Wraps input expression `e` with `if(isnull(e), null, false)`. The if-clause is represented - * using `and(isnull(e), null)` which is semantically equivalent by applying 3-valued logic. - */ - def falseIfNotNull: Expression = And(IsNull(e), Literal(null, BooleanType)) - - /** - * Wraps input expression `e` with `if(isnull(e), null, true)`. The if-clause is represented - * using `or(isnotnull(e), null)` which is semantically equivalent by applying 3-valued logic. - */ - def trueIfNotNull: Expression = Or(IsNotNull(e), Literal(null, BooleanType)) + /** + * Wraps input expression `e` with `if(isnull(e), null, false)`. The if-clause is represented + * using `and(isnull(e), null)` which is semantically equivalent by applying 3-valued logic. + */ + private[optimizer] def falseIfNotNull(e: Expression): Expression = { + And(IsNull(e), Literal(null, BooleanType)) + } + + /** + * Wraps input expression `e` with `if(isnull(e), null, true)`. 
The if-clause is represented + * using `or(isnotnull(e), null)` which is semantically equivalent by applying 3-valued logic. + */ + private[optimizer] def trueIfNotNull(e: Expression): Expression = { + Or(IsNotNull(e), Literal(null, BooleanType)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala index 4ee63be442e27..387964088b808 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala @@ -37,46 +37,46 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp } val testRelation: LocalRelation = LocalRelation('a.short, 'b.float) - val f = 'a.short.canBeNull.at(0) + val f: BoundReference = 'a.short.canBeNull.at(0) test("unwrap casts when literal == max") { val v = Short.MaxValue - assertEquivalent(castInt(f) > v.toInt, f.falseIfNotNull) + assertEquivalent(castInt(f) > v.toInt, falseIfNotNull(f)) assertEquivalent(castInt(f) >= v.toInt, f === v) assertEquivalent(castInt(f) === v.toInt, f === v) assertEquivalent(castInt(f) <=> v.toInt, f <=> v) - assertEquivalent(castInt(f) <= v.toInt, f.trueIfNotNull) + assertEquivalent(castInt(f) <= v.toInt, trueIfNotNull(f)) assertEquivalent(castInt(f) < v.toInt, f =!= v) } test("unwrap casts when literal > max") { val v: Int = positiveInt - assertEquivalent(castInt(f) > v, f.falseIfNotNull) - assertEquivalent(castInt(f) >= v, f.falseIfNotNull) - assertEquivalent(castInt(f) === v, f.falseIfNotNull) + assertEquivalent(castInt(f) > v, falseIfNotNull(f)) + assertEquivalent(castInt(f) >= v, falseIfNotNull(f)) + assertEquivalent(castInt(f) === v, falseIfNotNull(f)) assertEquivalent(castInt(f) <=> v, false) - assertEquivalent(castInt(f) <= v, 
f.trueIfNotNull) - assertEquivalent(castInt(f) < v, f.trueIfNotNull) + assertEquivalent(castInt(f) <= v, trueIfNotNull(f)) + assertEquivalent(castInt(f) < v, trueIfNotNull(f)) } test("unwrap casts when literal == min") { val v = Short.MinValue assertEquivalent(castInt(f) > v.toInt, f =!= v) - assertEquivalent(castInt(f) >= v.toInt, f.trueIfNotNull) + assertEquivalent(castInt(f) >= v.toInt, trueIfNotNull(f)) assertEquivalent(castInt(f) === v.toInt, f === v) assertEquivalent(castInt(f) <=> v.toInt, f <=> v) assertEquivalent(castInt(f) <= v.toInt, f === v) - assertEquivalent(castInt(f) < v.toInt, f.falseIfNotNull) + assertEquivalent(castInt(f) < v.toInt, falseIfNotNull(f)) } test("unwrap casts when literal < min") { val v: Int = negativeInt - assertEquivalent(castInt(f) > v, f.trueIfNotNull) - assertEquivalent(castInt(f) >= v, f.trueIfNotNull) - assertEquivalent(castInt(f) === v, f.falseIfNotNull) + assertEquivalent(castInt(f) > v, trueIfNotNull(f)) + assertEquivalent(castInt(f) >= v, trueIfNotNull(f)) + assertEquivalent(castInt(f) === v, falseIfNotNull(f)) assertEquivalent(castInt(f) <=> v, false) - assertEquivalent(castInt(f) <= v, f.falseIfNotNull) - assertEquivalent(castInt(f) < v, f.falseIfNotNull) + assertEquivalent(castInt(f) <= v, falseIfNotNull(f)) + assertEquivalent(castInt(f) < v, falseIfNotNull(f)) } test("unwrap casts when literal is within range (min, max)") { @@ -90,11 +90,11 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp test("unwrap casts when cast is on rhs") { val v = Short.MaxValue - assertEquivalent(Literal(v.toInt) < castInt(f), f.falseIfNotNull) + assertEquivalent(Literal(v.toInt) < castInt(f), falseIfNotNull(f)) assertEquivalent(Literal(v.toInt) <= castInt(f), Literal(v) === f) assertEquivalent(Literal(v.toInt) === castInt(f), Literal(v) === f) assertEquivalent(Literal(v.toInt) <=> castInt(f), Literal(v) <=> f) - assertEquivalent(Literal(v.toInt) >= castInt(f), f.trueIfNotNull) + 
assertEquivalent(Literal(v.toInt) >= castInt(f), trueIfNotNull(f)) assertEquivalent(Literal(v.toInt) > castInt(f), f =!= v) assertEquivalent(Literal(30) <= castInt(f), Literal(30.toShort) <= f) @@ -102,18 +102,18 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp test("unwrap cast should have no effect when input is not integral type") { Seq( - Cast('b, DoubleType) > 42.0, - Cast('b, DoubleType) >= 42.0, - Cast('b, DoubleType) === 42.0, - Cast('b, DoubleType) <=> 42.0, - Cast('b, DoubleType) <= 42.0, - Cast('b, DoubleType) < 42.0, - Literal(42.0) > Cast('b, DoubleType), - Literal(42.0) >= Cast('b, DoubleType), - Literal(42.0) === Cast('b, DoubleType), - Literal(42.0) <=> Cast('b, DoubleType), - Literal(42.0) <= Cast('b, DoubleType), - Literal(42.0) < Cast('b, DoubleType), + castDouble('b) > 42.0, + castDouble('b) >= 42.0, + castDouble('b) === 42.0, + castDouble('b) <=> 42.0, + castDouble('b) <= 42.0, + castDouble('b) < 42.0, + Literal(42.0) > castDouble('b), + Literal(42.0) >= castDouble('b), + Literal(42.0) === castDouble('b), + Literal(42.0) <=> castDouble('b), + Literal(42.0) <= castDouble('b), + Literal(42.0) < castDouble('b) ).foreach(e => assertEquivalent(e, e, evaluate = false) ) @@ -141,7 +141,9 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp assertEquivalent(Cast(f, ByteType) > 100.toByte, Cast(f, ByteType) > 100.toByte) } - private def castInt(f: BoundReference): Expression = Cast(f, IntegerType) + private def castInt(e: Expression): Expression = Cast(e, IntegerType) + + private def castDouble(e: Expression): Expression = Cast(e, DoubleType) private def assertEquivalent(e1: Expression, e2: Expression, evaluate: Boolean = true): Unit = { val plan = testRelation.where(e1).analyze From ec88961ff3b48388d8a7847fbdc57d061fc52ff2 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 10 Sep 2020 17:25:25 -0700 Subject: [PATCH 15/17] Oops forgot to update another test --- 
.../org/apache/spark/sql/FileBasedDataSourceSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 5eae588a7b7f8..48b2e22457e3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -924,8 +924,8 @@ class FileBasedDataSourceSuite extends QueryTest sources.EqualTo("id", v))) checkPushedFilters(format, df.where('id === v.toInt), Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters(format, df.where('id <=> v.toInt), Array(sources.IsNotNull("id"), - sources.EqualTo("id", v))) + checkPushedFilters(format, df.where('id <=> v.toInt), + Array(sources.EqualNullSafe("id", v))) checkPushedFilters(format, df.where('id <= v.toInt), Array(sources.IsNotNull("id"))) checkPushedFilters(format, df.where('id < v.toInt), Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) @@ -946,6 +946,8 @@ class FileBasedDataSourceSuite extends QueryTest checkPushedFilters(format, df.where(lit(v.toInt) <= 'id), Array(sources.IsNotNull("id"))) checkPushedFilters(format, df.where(lit(v.toInt) === 'id), Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters(format, df.where(lit(v.toInt) <=> 'id), + Array(sources.EqualNullSafe("id", v))) checkPushedFilters(format, df.where(lit(v.toInt) >= 'id), Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) checkPushedFilters(format, df.where(lit(v.toInt) > 'id), Array(), noScan = true) From de911da549fe7c39aed2b2ae92a100d9c1fff2ed Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Fri, 11 Sep 2020 09:57:38 -0700 Subject: [PATCH 16/17] Revert "Switch to pattern matching" This reverts commit 265c1698c29730b0dee2865773557d7a3c4c1144. 
--- .../UnwrapCastInBinaryComparison.scala | 114 ++++++++++-------- 1 file changed, 63 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index d3b9e4024f57e..89f7c0f71b7ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -130,58 +130,70 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { val ordering = toType.ordering.asInstanceOf[Ordering[Any]] val minCmp = ordering.compare(value, minInToType) val maxCmp = ordering.compare(value, maxInToType) - val lit = Cast(Literal(value), fromType) - (minCmp.signum, maxCmp.signum, exp) match { - case (_, 1, EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _)) => - falseIfNotNull(fromExp) - case (_, 1, LessThan(_, _) | LessThanOrEqual(_, _)) => - trueIfNotNull(fromExp) - // make sure the expression is evaluated if it is non-deterministic - case (_, 1, EqualNullSafe(_, _)) if exp.deterministic => - FalseLiteral - case (_, 1, _) => exp - - case (_, 0, GreaterThan(_, _)) => - falseIfNotNull(fromExp) - case (_, 0, LessThanOrEqual(_, _)) => - trueIfNotNull(fromExp) - case (_, 0, LessThan(_, _)) => - Not(EqualTo(fromExp, Literal(max, fromType))) - case (_, 0, GreaterThanOrEqual(_, _) | EqualTo(_, _)) => - EqualTo(fromExp, Literal(max, fromType)) - case (_, 0, EqualNullSafe(_, _)) => - EqualNullSafe(fromExp, Literal(max, fromType)) - case (_, 0, _) => exp - - case (-1, _, GreaterThan(_, _) | GreaterThanOrEqual(_, _)) => - trueIfNotNull(fromExp) - case (-1, _, LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _)) => - falseIfNotNull(fromExp) - // make sure the expression is evaluated if it is non-deterministic - case (-1, _, 
EqualNullSafe(_, _)) if exp.deterministic => - FalseLiteral - case (-1, _, _) => exp - - case (0, _, LessThan(_, _)) => - falseIfNotNull(fromExp) - case (0, _, GreaterThanOrEqual(_, _)) => - trueIfNotNull(fromExp) - case (0, _, GreaterThan(_, _)) => - Not(EqualTo(fromExp, Literal(min, fromType))) - case (0, _, LessThanOrEqual(_, _) | EqualTo(_, _)) => - EqualTo(fromExp, Literal(min, fromType)) - case (0, _, EqualNullSafe(_, _)) => - EqualNullSafe(fromExp, Literal(min, fromType)) - case (0, _, _) => exp - - case (_, _, GreaterThan(_, _)) => GreaterThan(fromExp, lit) - case (_, _, GreaterThanOrEqual(_, _)) => GreaterThanOrEqual(fromExp, lit) - case (_, _, EqualTo(_, _)) => EqualTo(fromExp, lit) - case (_, _, EqualNullSafe(_, _)) => EqualNullSafe(fromExp, lit) - case (_, _, LessThan(_, _)) => LessThan(fromExp, lit) - case (_, _, LessThanOrEqual(_, _)) => LessThanOrEqual(fromExp, lit) - case _ => exp + if (maxCmp > 0) { + exp match { + case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) => + falseIfNotNull(fromExp) + case LessThan(_, _) | LessThanOrEqual(_, _) => + trueIfNotNull(fromExp) + // make sure the expression is evaluated if it is non-deterministic + case EqualNullSafe(_, _) if exp.deterministic => + FalseLiteral + case _ => exp + } + } else if (maxCmp == 0) { + exp match { + case GreaterThan(_, _) => + falseIfNotNull(fromExp) + case LessThanOrEqual(_, _) => + trueIfNotNull(fromExp) + case LessThan(_, _) => + Not(EqualTo(fromExp, Literal(max, fromType))) + case GreaterThanOrEqual(_, _) | EqualTo(_, _) => + EqualTo(fromExp, Literal(max, fromType)) + case EqualNullSafe(_, _) => + EqualNullSafe(fromExp, Literal(max, fromType)) + case _ => exp + } + } else if (minCmp < 0) { + exp match { + case GreaterThan(_, _) | GreaterThanOrEqual(_, _) => + trueIfNotNull(fromExp) + case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) => + falseIfNotNull(fromExp) + // make sure the expression is evaluated if it is non-deterministic + case EqualNullSafe(_, _) 
if exp.deterministic => + FalseLiteral + case _ => exp + } + } else if (minCmp == 0) { + exp match { + case LessThan(_, _) => + falseIfNotNull(fromExp) + case GreaterThanOrEqual(_, _) => + trueIfNotNull(fromExp) + case GreaterThan(_, _) => + Not(EqualTo(fromExp, Literal(min, fromType))) + case LessThanOrEqual(_, _) | EqualTo(_, _) => + EqualTo(fromExp, Literal(min, fromType)) + case EqualNullSafe(_, _) => + EqualNullSafe(fromExp, Literal(min, fromType)) + case _ => exp + } + } else { + // This means `value` is within range `(min, max)`. Optimize this by moving the cast to the + // literal side. + val lit = Cast(Literal(value), fromType) + exp match { + case GreaterThan(_, _) => GreaterThan(fromExp, lit) + case GreaterThanOrEqual(_, _) => GreaterThanOrEqual(fromExp, lit) + case EqualTo(_, _) => EqualTo(fromExp, lit) + case EqualNullSafe(_, _) => EqualNullSafe(fromExp, lit) + case LessThan(_, _) => LessThan(fromExp, lit) + case LessThanOrEqual(_, _) => LessThanOrEqual(fromExp, lit) + case _ => exp + } } } From 5f32ee5fda08f4271959589c7f6312195898af08 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Fri, 11 Sep 2020 13:44:39 -0700 Subject: [PATCH 17/17] Re-trigger CI