Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ trait ConstraintHelper {
val candidateConstraints = constraints - eq
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
case eq @ EqualTo(l: Attribute, r : Literal) =>
inferredConstraints ++= replaceConstraints(constraints - eq, l, r)
case eq @ EqualTo(l : Literal, r: Attribute) =>
inferredConstraints ++= replaceConstraints(constraints - eq, r, l)
case _ => // No inference
}
inferredConstraints -- constraints
Expand All @@ -75,7 +79,7 @@ trait ConstraintHelper {
private def replaceConstraints(
constraints: Set[Expression],
source: Expression,
destination: Attribute): Set[Expression] = constraints.map(_ transform {
destination: Expression): Set[Expression] = constraints.map(_ transform {
case e: Expression if e.semanticEquals(source) => destination
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,16 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val y = testRelation.subquery('y)
testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
}

test("SPARK-30768: Constraints should be inferred from inequality attributes") {
val condition = Some("x.a".attr > "y.a".attr)
val optimizedLeft = testRelation.where(IsNotNull('a) && 'a === 1).as("x")
val optimizedRight = testRelation.where(Literal(1) > 'a && IsNotNull('a) ).as("y")
val correct = optimizedLeft.join(optimizedRight, Inner, condition)

Seq(Literal(1) === 'a, 'a === Literal(1)).foreach { filter =>
val original = testRelation.where(filter).as("x").join(testRelation.as("y"), Inner, condition)
comparePlans(Optimize.execute(original.analyze), correct.analyze)
}
}
}