diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 9600357f43cc9..53bb0e3a527cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -104,8 +104,8 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) - .reduce(And), child) + Filter(splitConjunctivePredicates(condition).map(rewriteBinaryComparison) + .sortBy(_.hashCode()).reduce(And), child) case sample: Sample => sample.copy(seed = 0L) case Join(left, right, joinType, condition, hint) if condition.isDefined => @@ -117,23 +117,26 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => } val newCondition = - splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) - .reduce(And) + splitConjunctivePredicates(condition.get).map(rewriteBinaryComparison) + .sortBy(_.hashCode()).reduce(And) Join(left, right, newJoinType, Some(newCondition), hint) } } /** - * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be + * Rewrite [[BinaryComparison]] operator to keep order. The following cases will be * equivalent: * 1. (a = b), (b = a); * 2. (a <=> b), (b <=> a). + * 3. (a > b), (b < a) */ - private def rewriteEqual(condition: Expression): Expression = condition match { - case eq @ EqualTo(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) - case eq @ EqualNullSafe(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + private def rewriteBinaryComparison(condition: Expression): Expression = condition match { + case EqualTo(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) + case EqualNullSafe(l, r) => Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l) + case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l) + case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) + case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) case _ => condition // Don't reorder. }