@@ -74,14 +74,26 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
* additional constraint of the form `b = 5`
*/
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
// Collect aliases from expressions to avoid producing a non-converging set of constraints
// for recursive functions.
//
// Don't transform a constraint when the attribute used as the replacement is an alias,
// because then both `QueryPlan.inferAdditionalConstraints` and
// `UnaryNode.getAliasedConstraints` apply and may produce a non-converging set of
// constraints.
// For more details, infer https://issues.apache.org/jira/browse/SPARK-17733
Contributor: typo "infer" -> "refer" (to)?

val aliasMap = AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
Contributor: Why not use AttributeSet?

Contributor Author: Yes, AttributeSet is a better choice here.

Member: Since aliasMap is referenced in a number of places, let's just make it a private lazy val and move it outside of this method in QueryPlan.

case a: Alias => (a.toAttribute, a.child)
})
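
A rough sketch of the refactor suggested above (an assumption about the shape, not the PR's final code; the exact visibility and placement would follow QueryPlan's existing members):

  // Sketch only: hoist the alias map into a private lazy val on QueryPlan so
  // callers reuse one map instead of rebuilding it on every invocation.
  private lazy val aliasMap: AttributeMap[Expression] =
    AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
      case a: Alias => (a.toAttribute, a.child)
    })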

var inferredConstraints = Set.empty[Expression]
constraints.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
inferredConstraints ++= (constraints - eq).map(_ transform {
-        case a: Attribute if a.semanticEquals(l) => r
+        case a: Attribute if a.semanticEquals(l) && !aliasMap.contains(r) => r
})
inferredConstraints ++= (constraints - eq).map(_ transform {
-        case a: Attribute if a.semanticEquals(r) => l
+        case a: Attribute if a.semanticEquals(r) && !aliasMap.contains(l) => l
})
case _ => // No inference
}
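
To make the non-convergence concrete, here is a minimal, self-contained Scala sketch of the failure mode (illustrative case classes standing in for Catalyst expressions; none of these names are Spark's API). When `int_col` aliases `Coalesce(a, b)`, as in the SPARK-17733 query, substituting an attribute through the alias grows the constraint on every pass, so no fixed point is ever reached — exactly what the `!aliasMap.contains(...)` guard above prevents:

object NonConvergenceDemo extends App {
  // Toy expression tree, standing in for Catalyst's Expression hierarchy.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Coalesce(children: Seq[Expr]) extends Expr

  // int_col is an alias for Coalesce(a, b), as in the failing query.
  val aliasChild: Expr = Coalesce(Seq(Attr("a"), Attr("b")))

  // Replace every occurrence of `from` inside `e` with `to`.
  def substitute(e: Expr, from: Attr, to: Expr): Expr = e match {
    case a: Attr if a == from => to
    case Coalesce(cs)         => Coalesce(cs.map(substitute(_, from, to)))
    case other                => other
  }

  // Each inference round rewrites `a` through the alias and strictly grows
  // the constraint: a -> Coalesce(a, b) -> Coalesce(Coalesce(a, b), b) -> ...
  var constraint: Expr = Attr("a")
  for (_ <- 1 to 3) {
    constraint = substitute(constraint, Attr("a"), aliasChild)
    println(constraint)
  }
}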
@@ -27,9 +27,12 @@ import org.apache.spark.sql.catalyst.rules._
class InferFiltersFromConstraintsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) ::
-      Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) ::
-      Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil
+    val batches =
Contributor Author: Previous batches will not apply InferFiltersFromConstraints after PushPredicateThroughJoin.

Batch("InferAndPushDownFilters", FixedPoint(100),
PushPredicateThroughJoin,
PushDownPredicate,
InferFiltersFromConstraints,
CombineFilters) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -120,4 +123,64 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("inner join with alias: alias contains multiple attributes") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)

val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val currectAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b)))
Member: nit: typo correctAnswer

&&'a === Coalesce(Seq('a, 'b)))
Member: nit: 2 spaces

.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2.where(IsNotNull('a)), Inner,
Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, currectAnswer)
}

test("inner join with alias: alias contains single attributes") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)

val originalQuery = t1.select('a, 'b.as('d)).as("t")
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val currectAnswer = t1.where(IsNotNull('a) && IsNotNull('b)
Member: nit: typo

&& 'a <=> 'a && 'b <=> 'b &&'a === 'b)
Member: nit: 2 spaces

.select('a, 'b.as('d)).as("t")
.join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, currectAnswer)
}

test("inner join with alias: don't generate constraints for recursive functions") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)

val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2, Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
&& 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
&& 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
}
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._

/**
@@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
 * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)),
 * etc., will all now be equivalent.
 * - Sample: the seed will be replaced by 0L.
 * - Join conditions will be re-sorted by hashCode.
*/
private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
-      Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
+      Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
+        .reduce(And), child)
case sample: Sample =>
sample.copy(seed = 0L)(true)
case join @ Join(left, right, joinType, condition) if condition.isDefined =>
val newCondition =
splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
.reduce(And)
Join(left, right, joinType, Some(newCondition))
}
}

/**
 * Rewrite the [[EqualTo]] and [[EqualNullSafe]] operators to normalize the order of their
 * operands. The following cases will then be equivalent:
* 1. (a = b), (b = a);
* 2. (a <=> b), (b <=> a).
*/
private def rewriteEqual(condition: Expression): Expression = condition match {
case eq @ EqualTo(l: Expression, r: Expression) =>
Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
case eq @ EqualNullSafe(l: Expression, r: Expression) =>
Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
case _ => condition // Don't reorder.
}
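
As a usage-level illustration of what this normalization buys (a sketch, assuming the catalyst DSL implicits — org.apache.spark.sql.catalyst.dsl.plans._ and org.apache.spark.sql.catalyst.dsl.expressions._ — are in scope inside a PlanTest subclass, as in the optimizer suites): two plans whose filter conjuncts differ only in ordering and in equality operand order should now compare equal.

  // Sketch: the plans differ only in conjunct order and in the operand order
  // of the equality, so normalizePlan rewrites both to the same Filter.
  val relation = LocalRelation('a.int, 'b.int)
  val plan1 = relation.where('a === 'b && 'b === 1).analyze
  val plan2 = relation.where('b === 1 && 'b === 'a).analyze
  comparePlans(plan1, plan2)  // passes after the rewrite above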

/** Fails the test if the two plans do not match */
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
val normalized1 = normalizePlan(normalizeExprIds(plan1))
24 changes: 24 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2678,4 +2678,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}

test("SPARK-17733 InferFiltersFromConstraints rule never terminates for query") {
Contributor: Can we construct a unit test rather than an end-to-end test here?

Contributor Author: Yes - perhaps we could add new test cases in InferFiltersFromConstraintsSuite.

Contributor: +1

Member: Given that you already have a unit test for cases like these, how about we remove this now? This test was randomly generated to catch issues like this, and in its current form it isn't very obvious how this query has anything to do with InferFiltersFromConstraints.

withTempView("tmpv") {
spark.range(10).toDF("a").createTempView("tmpv")

      // Just ensure the following query executes successfully to completion.
assert(sql(
"""
|SELECT
| *
|FROM (
| SELECT
| COALESCE(t1.a, t2.a) AS int_col,
| t1.a,
| t2.a AS b
| FROM tmpv t1
| CROSS JOIN tmpv t2
|) t1
|INNER JOIN tmpv t2
|ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
""".stripMargin).count() > 0
)
}
}
}