@@ -42,6 +42,7 @@ object DefaultOptimizer extends Optimizer {
4242 NullPropagation ,
4343 ConstantFolding ,
4444 LikeSimplification ,
45+ ConditionSimplification ,
4546 BooleanSimplification ,
4647 SimplifyFilters ,
4748 SimplifyCasts ,
@@ -302,7 +303,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
302303object ConditionSimplification extends Rule [LogicalPlan ] with PredicateHelper {
303304
304305 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
305- case q : LogicalPlan => q transformExpressionsDown {
306+ case q : LogicalPlan => q transformExpressionsUp {
307+ /** 1. one And/Or with same condition. */
306308 // a && a => a
307309 case And (left, right) if left.fastEquals(right) =>
308310 left
@@ -311,85 +313,139 @@ object ConditionSimplification extends Rule[LogicalPlan] with PredicateHelper {
311313 case Or (left, right) if left.fastEquals(right) =>
312314 left
313315
316+ /** 2. one And/Or with literal conditions that can be merged. */
314317 // a < 2 && a > 2 => false, a > 3 && a > 5 => a > 5
315318 case and @ And (
316- e1 @ NumericLiteralBinaryComparison (n1, i1),
317- e2 @ NumericLiteralBinaryComparison (n2, i2)) if n1 == n2 =>
319+ e1 @ NumLitBinComparison (n1, i1),
320+ e2 @ NumLitBinComparison (n2, i2)) if n1 == n2 =>
318321 if (! i1.intersects(i2)) Literal (false )
319322 else if (i1.isSubsetOf(i2)) e1
320323 else if (i1.isSupersetOf(i2)) e2
324+ else if (i1.intersect(i2).isPoint)
325+ EqualTo (n1, Literal (i1.intersect(i2).asInstanceOf [Point [Double ]].value, n1.dataType))
321326 else and
322327
323328 // a < 2 || a >= 2 => true, a > 3 || a > 5 => a > 3
324329 case or @ Or (
325- e1 @ NumericLiteralBinaryComparison (n1, i1),
326- e2 @ NumericLiteralBinaryComparison (n2, i2)) if n1 == n2 =>
327- if (i1.intersects(i2)) Literal (true )
330+ e1 @ NumLitBinComparison (n1, i1),
331+ e2 @ NumLitBinComparison (n2, i2)) if n1 == n2 =>
332+ // a hack to avoid bug of spire
333+ val op = Interval .all[Double ] -- i1
334+ if (i1.intersects(i2) && i1.union(i2) == Interval .all[Double ]) Literal (true )
335+ else if (op(op.size - 1 ) == i2) Literal (true )
328336 else if (i1.isSubsetOf(i2)) e2
329337 else if (i1.isSupersetOf(i2)) e1
330338 else or
331339
332- // (a < 3 && b > 5) || a > 2 => b > 5 || a > 2
333- case Or (left1 @ And (left2, right2), right1) =>
334- And (Or (left2, right1), Or (right2, right1))
335-
340+ /** 3. Two And/Or with literal condition that can be merged, do a transformation to reuse 2. */
336341 // (a < 3 || b > 5) || a > 2 => true, (b > 5 || a < 3) || a > 2 => true
337- case Or ( Or (
338- e1 @ NumericLiteralBinaryComparison (n1, i1), e2 @ NumericLiteralBinaryComparison (n2, i2)),
339- right @ NumericLiteralBinaryComparison (n3, i3)) =>
342+ case or @ Or (
343+ Or ( e1 @ NumLitBinComparison (n1, i1), e2 @ NumLitBinComparison (n2, i2)),
344+ right @ NumLitBinComparison (n3, i3)) =>
340345 if (n3 fastEquals n1) {
341346 Or (Or (e1, right), e2)
342- } else {
347+ } else if (n3 fastEquals n2) {
343348 Or (Or (e2, right), e1)
349+ } else {
350+ or
344351 }
345352
346353 // (b > 5 && a < 2) && a > 3 => false, (a < 2 && b > 5) && a > 3 => false
347- case And ( And (
348- e1 @ NumericLiteralBinaryComparison (n1, i1), e2 @ NumericLiteralBinaryComparison (n2, i2)),
349- right @ NumericLiteralBinaryComparison (n3, i3)) =>
354+ case and @ And (
355+ And ( e1 @ NumLitBinComparison (n1, i1), e2 @ NumLitBinComparison (n2, i2)),
356+ right @ NumLitBinComparison (n3, i3)) =>
350357 if (n3 fastEquals n1) {
351358 And (And (e1, right), e2)
352- } else {
359+ } else if (n3 fastEquals n2) {
353360 And (And (e2, right), e1)
361+ } else {
362+ and
354363 }
355364
356365 // (a < 2 || b > 5) && a > 3 => b > 5 && a > 3
357- case And (left1@ Or (left2, right2), right1) =>
358- Or (And (left2, right1), And (right2, right1))
359-
360- // (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... =>
366+ // using formula: a && (b || c) = (a && b) || (a && c)
367+ case And (
368+ left1 @ Or (left2 @ NumLitBinComparison (n1, i1), right2 @ NumLitBinComparison (n2, i2)),
369+ right1 @ NumLitBinComparison (n3, i3))
370+ if ((n3.fastEquals(n1) && i3 != i1) || (n3.fastEquals(n2) && i3 != i2)) =>
371+ Or (And (left2, right1), And (right2, right1))
372+
373+ // (a < 3 && b > 5) || a > 2 => b > 5 || a > 2.
374+ // using formula: a || (b && c) = (a || b) && (a || c)
375+ case Or (
376+ left1 @ And (left2 @ NumLitBinComparison (n1, i1), right2 @ NumLitBinComparison (n2, i2)),
377+ right1 @ NumLitBinComparison (n3, i3))
378+ if ((n3.fastEquals(n1) && i3 != i1) || (n3.fastEquals(n2) && i3 != i2)) =>
379+ Or (And (left2, right1), And (right2, right1))
380+
381+ /** 4. And/Or whose one child is literal condition, the other is Or/And */
382+ // (a < 2 || b > 5) && a < 2 => a < 2
383+ case And (
384+ left1 @ Or (left2 @ NumLitBinComparison (_, _), right2 @ NumLitBinComparison (_, _)),
385+ right1 @ NumLitBinComparison (_, _))
386+ if (right1 fastEquals left2) || (right1 fastEquals right2) =>
387+ right1
388+
389+ // (a < 3 && b > 5) || a < 3 => a < 3
390+ case Or (
391+ left1 @ And (left2 @ NumLitBinComparison (_, _), right2 @ NumLitBinComparison (_, _)),
392+ right1 @ NumLitBinComparison (_, _))
393+ if (right1 fastEquals left2) || (right1 fastEquals right2) =>
394+ right1
395+
396+ // 5. (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... =>
361397 // a && b && ((c && ...) || (d && ...) || (e && ...) || ...)
362398 case or @ Or (left, right) =>
363399 val lhsSet = splitConjunctivePredicates(left).toSet
364400 val rhsSet = splitConjunctivePredicates(right).toSet
365401 val common = lhsSet.intersect(rhsSet)
366- (lhsSet.diff(common).reduceOption(And ) ++ rhsSet.diff(common).reduceOption(And ))
367- .reduceOption(Or )
368- .map(_ :: common.toList)
369- .getOrElse(common.toList)
370- .reduce(And )
402+ val ldiff = lhsSet.diff(common)
403+ val rdiff = rhsSet.diff(common)
404+ if (common.size == 0 ) {
405+ or
406+ }else if ( ldiff.size == 0 || rdiff == 0 ) {
407+ common.reduce(And )
408+ } else {
409+ (ldiff.reduceOption(And ) ++ rdiff.reduceOption(And ))
410+ .reduceOption(Or )
411+ .map(_ :: common.toList)
412+ .getOrElse(common.toList)
413+ .reduce(And )
414+ }
371415
372- // (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... =>
416+ // 6. (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... =>
373417 // (a || b) || ((c || ...) && (f || ...) && (e || ...) && ...)
374418 case and @ And (left, right) =>
375419 val lhsSet = splitDisjunctivePredicates(left).toSet
376420 val rhsSet = splitDisjunctivePredicates(right).toSet
377421 val common = lhsSet.intersect(rhsSet)
378- (lhsSet.diff(common).reduceOption(Or ) ++ rhsSet.diff(common).reduceOption(Or ))
379- .reduceOption(And )
380- .map(_ :: common.toList)
381- .getOrElse(common.toList)
382- .reduce(Or )
422+ val ldiff = lhsSet.diff(common)
423+ val rdiff = rhsSet.diff(common)
424+
425+ if (common.size == 0 ) {
426+ and
427+ }else if (ldiff.size == 0 || rdiff.size == 0 ) {
428+ common.reduce(Or )
429+ } else {
430+ val x = (ldiff.reduceOption(Or ) ++ rdiff.reduceOption(Or ))
431+ .reduceOption(And )
432+ .map(_ :: common.toList)
433+ .getOrElse(common.toList)
434+ .reduce(Or )
435+ x
436+ }
437+
438+ case other => other
383439 }
384440 }
385441
386442 private implicit class NumericLiteral (e : Literal ) {
387443 def toDouble = Cast (e, DoubleType ).eval().asInstanceOf [Double ]
388444 }
389445
390- object NumericLiteralBinaryComparison {
446+ object NumLitBinComparison {
391447 def unapply (e : Expression ): Option [(NamedExpression , Interval [Double ])] = e match {
392- case LessThan (n : NamedExpression , l @ Literal (_, _ : NumericType )) => Some ((n, Interval .below(l.toDouble)))
448+ case LessThan (n : NamedExpression , l @ Literal (_, _ : NumericType )) => Some ((n, Interval .below(l.toDouble)))
393449 case LessThan (l @ Literal (_, _ : NumericType ), n : NamedExpression ) => Some ((n, Interval .atOrAbove(l.toDouble)))
394450
395451 case GreaterThan (n : NamedExpression , l @ Literal (_, _ : NumericType )) => Some ((n, Interval .above(l.toDouble)))
@@ -402,6 +458,7 @@ object ConditionSimplification extends Rule[LogicalPlan] with PredicateHelper {
402458 case GreaterThanOrEqual (l @ Literal (_, _ : NumericType ), n : NamedExpression ) => Some ((n, Interval .below(l.toDouble)))
403459
404460 case EqualTo (n : NamedExpression , l @ Literal (_, _ : NumericType )) => Some ((n, Interval .point(l.toDouble)))
461+ case other => None
405462 }
406463 }
407464}
0 commit comments