@@ -184,7 +184,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
} else {
// This means `value` is within range `(min, max)`. Optimize this by moving the cast to the
// literal side.
- val lit = Cast(Literal(value), fromType)
+ val lit = Literal(Cast(Literal(value), fromType).eval(), fromType)
exp match {
case GreaterThan(_, _) => GreaterThan(fromExp, lit)
case GreaterThanOrEqual(_, _) => GreaterThanOrEqual(fromExp, lit)
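As a side note for readers, here is a minimal standalone sketch (not part of the patch; it only assumes the Catalyst `Cast` and `Literal` classes used above are on the classpath) of what the one-line change does: the cast is now evaluated once at rewrite time, so the rewritten comparison carries a plain `Literal` of the source type rather than an unevaluated `Cast` node.

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.ShortType

// Mirror the rule's inputs: `value` is the literal side of the comparison,
// `fromType` the type of the non-literal side.
val value: Any = 30
val fromType = ShortType

// Old: the rewritten predicate still contains a Cast node; it only becomes a
// plain literal if ConstantFolding happens to run afterwards.
val before = Cast(Literal(value), fromType)

// New: evaluate the cast eagerly, yielding Literal(30.toShort, ShortType).
val after = Literal(Cast(Literal(value), fromType).eval(), fromType)
```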
@@ -202,9 +202,12 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
* i.e., the conversion is injective. Note this only handles the case when both sides are of
* integral type.
*/
- private def canImplicitlyCast(fromExp: Expression, toType: DataType,
+ private def canImplicitlyCast(
+     fromExp: Expression,
+     toType: DataType,
literalType: DataType): Boolean = {
toType.sameType(literalType) &&
+ !fromExp.foldable &&
fromExp.dataType.isInstanceOf[IntegralType] &&
toType.isInstanceOf[IntegralType] &&
Cast.canUpCast(fromExp.dataType, toType)
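For context, a small standalone illustration of what the new `!fromExp.foldable` guard rejects (a sketch, assuming only the Catalyst classes already referenced here): when the casted expression is itself a constant, the whole comparison can be constant-folded, so the rule now leaves it alone instead of unwrapping it.

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.IntegerType

// A foldable input, matching the new test case added in the suite below.
val foldableSide = Cast(Literal(30.toShort), IntegerType)
assert(foldableSide.foldable)

// canImplicitlyCast now returns false here, so expressions such as
// Cast(Literal(30.toShort), IntegerType) >= v are left for ConstantFolding.
```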
@@ -26,14 +26,14 @@ import org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
- import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, IntegerType}
+ import org.apache.spark.sql.types._

class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches: List[Batch] =
Batch("Unwrap casts in binary comparison", FixedPoint(10),
- NullPropagation, ConstantFolding, UnwrapCastInBinaryComparison) :: Nil
+ NullPropagation, UnwrapCastInBinaryComparison) :: Nil
}

val testRelation: LocalRelation = LocalRelation('a.short, 'b.float)
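With ConstantFolding removed from the batch, the suite now exercises the unwrap rule's raw output, so the rule itself has to emit already-evaluated literals. Below is a hypothetical usage sketch, not part of the PR: the dsl-based plan construction and the exact optimized form are assumptions, while `testRelation` and `Optimize` are the suite members shown above.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.IntegerType

// Build a filter over the short column 'a and run only the trimmed-down batch.
val query = testRelation.where(Cast('a, IntegerType) >= Literal(30)).analyze
val optimized = Optimize.execute(query)

// Expected with the fix: the condition compares 'a against a plain
// Literal(30.toShort, ShortType); before the fix it still contained
// Cast(Literal(30), ShortType), which only ConstantFolding would simplify.
```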
@@ -97,7 +97,7 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelper {
assertEquivalent(Literal(v.toInt) >= castInt(f), trueIfNotNull(f))
assertEquivalent(Literal(v.toInt) > castInt(f), f =!= v)

- assertEquivalent(Literal(30) <= castInt(f), Literal(30.toShort) <= f)
+ assertEquivalent(Literal(30) <= castInt(f), Literal(30.toShort, ShortType) <= f)
}

test("unwrap cast should have no effect when input is not integral type") {
@@ -119,10 +119,12 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelper {
)
}

- test("unwrap cast should skip when expression is non-deterministic") {
+ test("unwrap cast should skip when expression is non-deterministic or foldable") {
Seq(positiveInt, negativeInt).foreach (v => {
val e = Cast(First(f, ignoreNulls = true), IntegerType) <=> v
assertEquivalent(e, e, evaluate = false)
+ val e2 = Cast(Literal(30.toShort), IntegerType) >= v
+ assertEquivalent(e2, e2, evaluate = false)
})
}

@@ -32,14 +32,13 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
- import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.execution.SimpleMode
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
- import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation, FileScan}
+ import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
- import org.apache.spark.sql.execution.datasources.v2.parquet.{ParquetScan, ParquetTable}
+ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf