diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index 89f7c0f71b7ac..d0acfe036d443 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -184,7 +184,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     } else {
       // This means `value` is within range `(min, max)`. Optimize this by moving the cast to the
       // literal side.
-      val lit = Cast(Literal(value), fromType)
+      val lit = Literal(Cast(Literal(value), fromType).eval(), fromType)
       exp match {
         case GreaterThan(_, _) => GreaterThan(fromExp, lit)
         case GreaterThanOrEqual(_, _) => GreaterThanOrEqual(fromExp, lit)
@@ -202,9 +202,12 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * i.e., the conversion is injective. Note this only handles the case when both sides are of
    * integral type.
    */
-  private def canImplicitlyCast(fromExp: Expression, toType: DataType,
+  private def canImplicitlyCast(
+      fromExp: Expression,
+      toType: DataType,
       literalType: DataType): Boolean = {
     toType.sameType(literalType) &&
+      !fromExp.foldable &&
       fromExp.dataType.isInstanceOf[IntegralType] &&
       toType.isInstanceOf[IntegralType] &&
       Cast.canUpCast(fromExp.dataType, toType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 387964088b808..373c1febd2488 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -26,14 +26,14 @@ import org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, IntegerType}
+import org.apache.spark.sql.types._
 
 class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelper {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches: List[Batch] =
       Batch("Unwrap casts in binary comparison", FixedPoint(10),
-        NullPropagation, ConstantFolding, UnwrapCastInBinaryComparison) :: Nil
+        NullPropagation, UnwrapCastInBinaryComparison) :: Nil
   }
 
   val testRelation: LocalRelation = LocalRelation('a.short, 'b.float)
@@ -97,7 +97,7 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     assertEquivalent(Literal(v.toInt) >= castInt(f), trueIfNotNull(f))
     assertEquivalent(Literal(v.toInt) > castInt(f), f =!= v)
 
-    assertEquivalent(Literal(30) <= castInt(f), Literal(30.toShort) <= f)
+    assertEquivalent(Literal(30) <= castInt(f), Literal(30.toShort, ShortType) <= f)
   }
 
   test("unwrap cast should have no effect when input is not integral type") {
@@ -119,10 +119,12 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     )
   }
 
-  test("unwrap cast should skip when expression is non-deterministic") {
+  test("unwrap cast should skip when expression is non-deterministic or foldable") {
     Seq(positiveInt, negativeInt).foreach (v => {
       val e = Cast(First(f, ignoreNulls = true), IntegerType) <=> v
       assertEquivalent(e, e, evaluate = false)
+      val e2 = Cast(Literal(30.toShort), IntegerType) >= v
+      assertEquivalent(e2, e2, evaluate = false)
     })
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 8d6d93d13d143..f72e3347510f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -32,14 +32,13 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT}
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.execution.SimpleMode
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.FilePartition
-import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation, FileScan}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
-import org.apache.spark.sql.execution.datasources.v2.parquet.{ParquetScan, ParquetTable}
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
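A minimal standalone sketch, not part of the patch, using only Catalyst's existing Literal and Cast APIs, of the two behaviors the diff relies on: the literal-side cast is now evaluated eagerly into a plain Literal (so the rewritten comparison no longer depends on ConstantFolding), and a foldable input is skipped by the new !fromExp.foldable guard.

import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.{IntegerType, ShortType}

// Evaluate the cast eagerly and wrap the result back into a Literal, so the
// rewritten comparison carries no Cast node left over for ConstantFolding.
val lit = Literal(Cast(Literal(30), ShortType).eval(), ShortType)

// A cast over a literal is itself foldable; the new `!fromExp.foldable` check
// makes the rule leave such expressions to constant folding instead.
val foldableInput = Cast(Literal(30.toShort), IntegerType)
assert(foldableInput.foldable)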