Skip to content

Commit 32a595b

Browse files
committed
improvement and test fix
1 parent e99a26c commit 32a595b

File tree

3 files changed

+120
-62
lines changed

3 files changed

+120
-62
lines changed

pom.xml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@
156156
<id>central</id>
157157
<!-- This should be at the top: it makes Maven try the central repo first and then the others, which gives faster dependency resolution. -->
158158
<name>Maven Repository</name>
159-
<url>https://repo1.maven.org/maven2</url>
159+
<url>http://repo1.maven.org/maven2</url>
160160
<releases>
161161
<enabled>true</enabled>
162162
</releases>
@@ -167,7 +167,7 @@
167167
<repository>
168168
<id>apache-repo</id>
169169
<name>Apache Repository</name>
170-
<url>https://repository.apache.org/content/repositories/releases</url>
170+
<url>http://repository.apache.org/content/repositories/releases</url>
171171
<releases>
172172
<enabled>true</enabled>
173173
</releases>
@@ -178,7 +178,7 @@
178178
<repository>
179179
<id>jboss-repo</id>
180180
<name>JBoss Repository</name>
181-
<url>https://repository.jboss.org/nexus/content/repositories/releases</url>
181+
<url>http://repository.jboss.org/nexus/content/repositories/releases</url>
182182
<releases>
183183
<enabled>true</enabled>
184184
</releases>
@@ -189,7 +189,7 @@
189189
<repository>
190190
<id>mqtt-repo</id>
191191
<name>MQTT Repository</name>
192-
<url>https://repo.eclipse.org/content/repositories/paho-releases</url>
192+
<url>http://repo.eclipse.org/content/repositories/paho-releases</url>
193193
<releases>
194194
<enabled>true</enabled>
195195
</releases>
@@ -200,7 +200,7 @@
200200
<repository>
201201
<id>cloudera-repo</id>
202202
<name>Cloudera Repository</name>
203-
<url>https://repository.cloudera.com/artifactory/cloudera-repos</url>
203+
<url>http://repository.cloudera.com/artifactory/cloudera-repos</url>
204204
<releases>
205205
<enabled>true</enabled>
206206
</releases>
@@ -222,7 +222,7 @@
222222
<repository>
223223
<id>spring-releases</id>
224224
<name>Spring Release Repository</name>
225-
<url>https://repo.spring.io/libs-release</url>
225+
<url>http://repo.spring.io/libs-release</url>
226226
<releases>
227227
<enabled>true</enabled>
228228
</releases>
@@ -234,7 +234,7 @@
234234
<pluginRepositories>
235235
<pluginRepository>
236236
<id>central</id>
237-
<url>https://repo1.maven.org/maven2</url>
237+
<url>http://repo1.maven.org/maven2</url>
238238
<releases>
239239
<enabled>true</enabled>
240240
</releases>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 92 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ object DefaultOptimizer extends Optimizer {
4242
NullPropagation,
4343
ConstantFolding,
4444
LikeSimplification,
45+
ConditionSimplification,
4546
BooleanSimplification,
4647
SimplifyFilters,
4748
SimplifyCasts,
@@ -302,7 +303,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
302303
object ConditionSimplification extends Rule[LogicalPlan] with PredicateHelper {
303304

304305
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
305-
case q: LogicalPlan => q transformExpressionsDown {
306+
case q: LogicalPlan => q transformExpressionsUp {
307+
/** 1. A single And/Or whose two children are the same condition. */
306308
// a && a => a
307309
case And(left, right) if left.fastEquals(right) =>
308310
left
@@ -311,85 +313,139 @@ object ConditionSimplification extends Rule[LogicalPlan] with PredicateHelper {
311313
case Or(left, right) if left.fastEquals(right) =>
312314
left
313315

316+
/** 2. A single And/Or of two literal comparisons on the same expression that can be merged. */
314317
// a < 2 && a > 2 => false, a > 3 && a > 5 => a > 5
315318
case and @ And(
316-
e1 @ NumericLiteralBinaryComparison(n1, i1),
317-
e2 @ NumericLiteralBinaryComparison(n2, i2)) if n1 == n2 =>
319+
e1 @ NumLitBinComparison(n1, i1),
320+
e2 @ NumLitBinComparison(n2, i2)) if n1 == n2 =>
318321
if (!i1.intersects(i2)) Literal(false)
319322
else if (i1.isSubsetOf(i2)) e1
320323
else if (i1.isSupersetOf(i2)) e2
324+
else if (i1.intersect(i2).isPoint)
325+
EqualTo(n1, Literal(i1.intersect(i2).asInstanceOf[Point[Double]].value, n1.dataType))
321326
else and
322327

323328
// a < 2 || a >= 2 => true, a > 3 || a > 5 => a > 3
324329
case or @ Or(
325-
e1 @ NumericLiteralBinaryComparison(n1, i1),
326-
e2 @ NumericLiteralBinaryComparison(n2, i2)) if n1 == n2 =>
327-
if (i1.intersects(i2)) Literal(true)
330+
e1 @ NumLitBinComparison(n1, i1),
331+
e2 @ NumLitBinComparison(n2, i2)) if n1 == n2 =>
332+
// Hack to work around a bug in spire: compute the complement of i1 so it can be compared with i2 below.
333+
val op = Interval.all[Double] -- i1
334+
if (i1.intersects(i2) && i1.union(i2) == Interval.all[Double]) Literal(true)
335+
else if (op(op.size - 1) == i2) Literal(true)
328336
else if (i1.isSubsetOf(i2)) e2
329337
else if (i1.isSupersetOf(i2)) e1
330338
else or
331339

332-
// (a < 3 && b > 5) || a > 2 => b > 5 || a > 2
333-
case Or(left1 @ And(left2, right2), right1) =>
334-
And(Or(left2, right1), Or(right2, right1))
335-
340+
/** 3. Two nested And/Or with literal conditions that can be merged: reorder the operands so that rule 2 applies. */
336341
// (a < 3 || b > 5) || a > 2 => true, (b > 5 || a < 3) || a > 2 => true
337-
case Or( Or(
338-
e1 @ NumericLiteralBinaryComparison(n1, i1), e2 @ NumericLiteralBinaryComparison(n2, i2)),
339-
right @ NumericLiteralBinaryComparison(n3, i3)) =>
342+
case or @ Or(
343+
Or(e1 @ NumLitBinComparison(n1, i1), e2 @ NumLitBinComparison(n2, i2)),
344+
right @ NumLitBinComparison(n3, i3)) =>
340345
if (n3 fastEquals n1) {
341346
Or(Or(e1, right), e2)
342-
} else {
347+
} else if (n3 fastEquals n2) {
343348
Or(Or(e2, right), e1)
349+
} else {
350+
or
344351
}
345352

346353
// (b > 5 && a < 2) && a > 3 => false, (a < 2 && b > 5) && a > 3 => false
347-
case And(And(
348-
e1 @ NumericLiteralBinaryComparison(n1, i1), e2 @ NumericLiteralBinaryComparison(n2, i2)),
349-
right @ NumericLiteralBinaryComparison(n3, i3)) =>
354+
case and @ And(
355+
And(e1 @ NumLitBinComparison(n1, i1), e2 @ NumLitBinComparison(n2, i2)),
356+
right @ NumLitBinComparison(n3, i3)) =>
350357
if (n3 fastEquals n1) {
351358
And(And(e1, right), e2)
352-
} else {
359+
} else if (n3 fastEquals n2) {
353360
And(And(e2, right), e1)
361+
} else {
362+
and
354363
}
355364

356365
// (a < 2 || b > 5) && a > 3 => b > 5 && a > 3
357-
case And(left1@Or(left2, right2), right1) =>
358-
Or(And(left2, right1), And(right2, right1))
359-
360-
// (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... =>
366+
// using formula: a && (b || c) = (a && b) || (a && c)
367+
case And(
368+
left1 @ Or(left2 @ NumLitBinComparison(n1, i1), right2 @ NumLitBinComparison(n2, i2)),
369+
right1 @ NumLitBinComparison(n3, i3))
370+
if ((n3.fastEquals(n1) && i3 != i1) || (n3.fastEquals(n2) && i3 != i2)) =>
371+
Or(And(left2, right1), And(right2, right1))
372+
373+
// (a < 3 && b > 5) || a > 2 => b > 5 || a > 2.
374+
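// using formula: a || (b && c) = (a || b) && (a || c)
// NOTE(review): the rewrite below builds Or(And(..), And(..)), which does not match this formula; confirm it should be And(Or(..), Or(..)).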
375+
case Or(
376+
left1 @ And(left2 @ NumLitBinComparison(n1, i1), right2 @ NumLitBinComparison(n2, i2)),
377+
right1 @ NumLitBinComparison(n3, i3))
378+
if ((n3.fastEquals(n1) && i3 != i1) || (n3.fastEquals(n2) && i3 != i2)) =>
379+
Or(And(left2, right1), And(right2, right1))
380+
381+
/** 4. And/Or where one child is a literal condition and the other is an Or/And containing that same condition. */
382+
// (a < 2 || b > 5) && a < 2 => a < 2
383+
case And(
384+
left1 @ Or(left2 @ NumLitBinComparison(_, _), right2 @ NumLitBinComparison(_, _)),
385+
right1 @ NumLitBinComparison(_, _))
386+
if (right1 fastEquals left2) || (right1 fastEquals right2) =>
387+
right1
388+
389+
// (a < 3 && b > 5) || a < 3 => a < 3
390+
case Or(
391+
left1 @ And(left2 @ NumLitBinComparison(_, _), right2 @ NumLitBinComparison(_, _)),
392+
right1 @ NumLitBinComparison(_, _))
393+
if (right1 fastEquals left2) || (right1 fastEquals right2) =>
394+
right1
395+
396+
// 5. (a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ... =>
361397
// a && b && ((c && ...) || (d && ...) || (e && ...) || ...)
362398
case or @ Or(left, right) =>
363399
val lhsSet = splitConjunctivePredicates(left).toSet
364400
val rhsSet = splitConjunctivePredicates(right).toSet
365401
val common = lhsSet.intersect(rhsSet)
366-
(lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And))
367-
.reduceOption(Or)
368-
.map(_ :: common.toList)
369-
.getOrElse(common.toList)
370-
.reduce(And)
402+
val ldiff = lhsSet.diff(common)
403+
val rdiff = rhsSet.diff(common)
404+
if (common.size == 0) {
405+
or
406+
}else if ( ldiff.size == 0 || rdiff == 0) {
407+
common.reduce(And)
408+
} else {
409+
(ldiff.reduceOption(And) ++ rdiff.reduceOption(And))
410+
.reduceOption(Or)
411+
.map(_ :: common.toList)
412+
.getOrElse(common.toList)
413+
.reduce(And)
414+
}
371415

372-
// (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... =>
416+
// 6. (a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ... =>
373417
// (a || b) || ((c || ...) && (d || ...) && (e || ...) && ...)
374418
case and @ And(left, right) =>
375419
val lhsSet = splitDisjunctivePredicates(left).toSet
376420
val rhsSet = splitDisjunctivePredicates(right).toSet
377421
val common = lhsSet.intersect(rhsSet)
378-
(lhsSet.diff(common).reduceOption(Or) ++ rhsSet.diff(common).reduceOption(Or))
379-
.reduceOption(And)
380-
.map(_ :: common.toList)
381-
.getOrElse(common.toList)
382-
.reduce(Or)
422+
val ldiff = lhsSet.diff(common)
423+
val rdiff = rhsSet.diff(common)
424+
425+
if (common.size == 0) {
426+
and
427+
}else if (ldiff.size == 0 || rdiff.size == 0) {
428+
common.reduce(Or)
429+
} else {
430+
val x = (ldiff.reduceOption(Or) ++ rdiff.reduceOption(Or))
431+
.reduceOption(And)
432+
.map(_ :: common.toList)
433+
.getOrElse(common.toList)
434+
.reduce(Or)
435+
x
436+
}
437+
438+
case other => other
383439
}
384440
}
385441

386442
private implicit class NumericLiteral(e: Literal) {
387443
def toDouble = Cast(e, DoubleType).eval().asInstanceOf[Double]
388444
}
389445

390-
object NumericLiteralBinaryComparison {
446+
object NumLitBinComparison {
391447
def unapply(e: Expression): Option[(NamedExpression, Interval[Double])] = e match {
392-
case LessThan(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.below(l.toDouble)))
448+
case LessThan(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.below(l.toDouble)))
393449
case LessThan(l @ Literal(_, _: NumericType), n: NamedExpression) => Some((n, Interval.atOrAbove(l.toDouble)))
394450

395451
case GreaterThan(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.above(l.toDouble)))
@@ -402,6 +458,7 @@ object ConditionSimplification extends Rule[LogicalPlan] with PredicateHelper {
402458
case GreaterThanOrEqual(l @ Literal(_, _: NumericType), n: NamedExpression) => Some((n, Interval.below(l.toDouble)))
403459

404460
case EqualTo(n: NamedExpression, l @ Literal(_, _: NumericType)) => Some((n, Interval.point(l.toDouble)))
461+
case other => None
405462
}
406463
}
407464
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConditionSimplificationSuite.scala

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class ConditionSimplificationSuite extends PlanTest {
3131
val batches =
3232
Batch("AnalysisNodes", Once,
3333
EliminateAnalysisOperators) ::
34-
Batch("Constant Folding", FixedPoint(10),
34+
Batch("Constant Folding", FixedPoint(50),
3535
NullPropagation,
3636
ConstantFolding,
3737
ConditionSimplification,
@@ -55,11 +55,12 @@ class ConditionSimplificationSuite extends PlanTest {
5555
comparePlans(optimized, expected)
5656
}
5757

58-
test("literal in front of attribute") {
59-
checkCondition(Literal(1) < 'a || Literal(2) < 'a, 'a > 1)
58+
test("literals in front of attribute") {
59+
checkCondition(Literal(1) < 'a || Literal(2) < 'a, Literal(1) < 'a)
60+
checkCondition(Literal(1) < 'a && Literal(2) < 'a, Literal(2) < 'a)
6061
}
6162

62-
test("combine the same condition") {
63+
test("And/Or with the same conditions") {
6364
checkCondition('a < 1 || 'a < 1, 'a < 1)
6465
checkCondition('a < 1 || 'a < 1 || 'a < 1 || 'a < 1, 'a < 1)
6566
checkCondition('a > 2 && 'a > 2, 'a > 2)
@@ -69,7 +70,7 @@ class ConditionSimplificationSuite extends PlanTest {
6970

7071
test("combine literal binary comparison") {
7172
checkCondition('a === 1 && 'a < 1)
72-
checkCondition('a === 1 || 'a < 1, 'a <= 1)
73+
checkCondition('a === 1 || 'a < 1, 'a === 1 || 'a < 1)
7374

7475
checkCondition('a === 1 && 'a === 2)
7576
checkCondition('a === 1 || 'a === 2, 'a === 1 || 'a === 2)
@@ -83,6 +84,10 @@ class ConditionSimplificationSuite extends PlanTest {
8384
checkCondition('a > 3 && 'a > 2, 'a > 3)
8485
checkCondition('a > 3 || 'a > 2, 'a > 2)
8586

87+
checkCondition('a < 2 || 'a === 3 , 'a < 2 || 'a === 3)
88+
checkCondition('a === 3 || 'a > 5, 'a === 3 || 'a > 5)
89+
checkCondition('a < 2 || 'a > 5, 'a < 2 || 'a > 5)
90+
8691
checkCondition('a >= 1 && 'a <= 1, 'a === 1)
8792

8893
}
@@ -103,56 +108,52 @@ class ConditionSimplificationSuite extends PlanTest {
103108
checkCondition('a < 1 || 'b > 2 || 'a >= 1)
104109
checkCondition('a < 1 && 'b > 2 && 'a >= 1)
105110

106-
checkCondition('a < 2 || 'b > 3 || 'b > 2, 'a < 2 || 'b > 2)
107-
checkCondition('a < 2 && 'b > 3 && 'b > 2, 'a < 2 && 'b > 3)
111+
checkCondition('a < 2 || 'b > 3 || 'b > 2, 'b > 2 || 'a < 2)
112+
checkCondition('a < 2 && 'b > 3 && 'b > 2, 'b > 3 && 'a < 2)
108113

109-
checkCondition('a < 2 || ('b > 3 || 'b > 2), 'b > 2 || 'a < 2)
110-
checkCondition('a < 2 && ('b > 3 && 'b > 2), 'b > 3 && 'a < 2)
114+
checkCondition('a < 2 || ('b > 3 || 'b > 2), 'a < 2 || 'b > 2)
115+
checkCondition('a < 2 && ('b > 3 && 'b > 2), 'a < 2 && 'b > 3)
111116

112117
checkCondition('a < 2 || 'a === 3 || 'a > 5, 'a < 2 || 'a === 3 || 'a > 5)
113118
}
114119

115120
test("combine predicate : 2 difference combine") {
116121
checkCondition(('a < 2 || 'a > 3) && 'a > 4, 'a > 4)
117122
checkCondition(('a < 2 || 'b > 3) && 'a < 2, 'a < 2)
118-
119-
checkCondition('a < 2 || ('a >= 2 && 'b > 1), 'b > 1 || 'a < 2)
120-
checkCondition('a < 2 || ('a === 2 && 'b > 1), 'a < 2 || ('a === 2 && 'b > 1))
121-
122-
checkCondition('a > 3 || ('a > 2 && 'a < 4), 'a > 2)
123+
checkCondition(('a < 2 && 'b > 3) || 'a < 2, 'a < 2)
123124
}
124125

125126
test("multi left, single right") {
126127
checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2)
127128
}
128129

129130
test("multi left, multi right") {
130-
checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5))
131+
checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), ('b > 3 && 'c > 5) || 'a < 2)
131132

132133
var input: Expression = ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5)
133-
var expected: Expression = 'a === 'b || ('b > 3 && 'a > 3 && 'a < 5)
134+
var expected: Expression = ('a > 3 && 'a < 5 && 'b > 3) || 'a === 'b
134135
checkCondition(input, expected)
135136

136137
input = ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a > 1)
137-
expected = 'a === 'b || ('b > 3 && 'a > 3)
138+
expected = ('a > 3 && 'b > 3) || 'a === 'b
138139
checkCondition(input, expected)
139140

140141
input = ('a === 'b && 'b > 3 && 'c > 2) ||
141142
('a === 'b && 'c < 1 && 'a === 5) ||
142143
('a === 'b && 'b < 5 && 'a > 1)
143144

144-
expected = ('a === 'b) &&
145+
expected =
145146
(((('b > 3) && ('c > 2)) ||
146147
(('c < 1) && ('a === 5))) ||
147-
(('b < 5) && ('a > 1)))
148+
(('b < 5) && ('a > 1))) && ('a === 'b)
148149
checkCondition(input, expected)
149150

150151
input = ('a < 2 || 'b > 5 || 'a < 2 || 'b > 1) && ('a < 2 || 'b > 1)
151152
expected = 'a < 2 || 'b > 1
152153
checkCondition(input, expected)
153154

154155
input = ('a === 'b || 'b > 5) && ('a === 'b || 'c > 3) && ('a === 'b || 'b > 1)
155-
expected = ('a === 'b) || ('c > 3 && 'b > 5)
156+
expected = ('b > 5 && 'c > 3) || ('a === 'b)
156157
checkCondition(input, expected)
157158
}
158159
}

0 commit comments

Comments
 (0)