@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
 
 /**
  * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -142,25 +141,6 @@ class Analyzer(
   }
 
   object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
-    /**
-     * Extract attribute set according to the grouping id
-     * @param bitmask bitmask to represent the selected of the attribute sequence
-     * @param exprs the attributes in sequence
-     * @return the attributes of non selected specified via bitmask (with the bit set to 1)
-     */
-    private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-      : OpenHashSet[Expression] = {
-      val set = new OpenHashSet[Expression](2)
-
-      var bit = exprs.length - 1
-      while (bit >= 0) {
-        if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
-        bit -= 1
-      }
-
-      set
-    }
-
     /*
      * GROUP BY a, b, c WITH ROLLUP
      * is equivalent to
@@ -197,10 +177,15 @@ class Analyzer(
 
       g.bitmasks.foreach { bitmask =>
         // get the non selected grouping attributes according to the bit mask
-        val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
+        val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
+        var bit = g.groupByExprs.length - 1
+        while (bit >= 0) {
+          if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
+          bit -= 1
+        }
 
         val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
-          case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+          case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
             // if the input attribute in the Invalid Grouping Expression set of for this group
             // replace it with constant null
             Literal.create(null, expr.dataType)
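For context, here is a minimal standalone sketch of what the new inline loop above computes (plain Scala, not the Catalyst classes; the grouping expressions are represented as strings purely for illustration): for each grouping-set bitmask, the expressions whose bit is 0 are collected as "non selected" and are the ones later replaced with null literals in the Expand output.

import scala.collection.mutable.ArrayBuffer

object RollupBitmaskSketch {
  // Collect the expressions whose bit in `bitmask` is 0, scanning from the highest
  // index down, mirroring the while-loop in the patch.
  def nonSelected[A](bitmask: Int, exprs: Seq[A]): Seq[A] = {
    val buf = ArrayBuffer.empty[A]
    var bit = exprs.length - 1
    while (bit >= 0) {
      if (((bitmask >> bit) & 1) == 0) buf += exprs(bit)
      bit -= 1
    }
    buf.toSeq
  }

  def main(args: Array[String]): Unit = {
    val groupBy = Seq("a", "b", "c")
    // With mask 0b011, positions 0 and 1 keep their values and position 2 ("c") is
    // "non selected", i.e. it would be substituted with a null literal for this grouping.
    println(nonSelected(0x3, groupBy).mkString(", "))  // prints: c
  }
}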
@@ -86,12 +86,12 @@ trait CheckAnalysis {
         case Aggregate(groupingExprs, aggregateExprs, child) =>
           def checkValidAggregateExpression(expr: Expression): Unit = expr match {
             case _: AggregateExpression => // OK
-            case e: Attribute if !groupingExprs.contains(e) =>
+            case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
               failAnalysis(
                 s"expression '${e.prettyString}' is neither present in the group by, " +
                   s"nor is it an aggregate function. " +
                   "Add to group by or wrap in first() if you don't care which value you get.")
-            case e if groupingExprs.contains(e) => // OK
+            case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
             case e if e.references.isEmpty => // OK
             case e => e.children.foreach(checkValidAggregateExpression)
           }
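To make the intent of this change concrete, here is a tiny self-contained sketch (plain Scala, not the Catalyst rule; the string-based "expressions" and the case-insensitive comparison are stand-ins for semanticEquals): a selected attribute passes only if some grouping expression matches it semantically, instead of the exact `contains` lookup used before.

object GroupByCheckSketch {
  // Stand-in for semanticEquals: names that differ only in case count as the same expression.
  private def semanticallyEqual(a: String, b: String): Boolean = a.equalsIgnoreCase(b)

  def checkValidAggregate(selectAttrs: Seq[String], groupingAttrs: Seq[String]): Unit =
    selectAttrs.foreach { attr =>
      if (groupingAttrs.find(semanticallyEqual(_, attr)).isEmpty) {
        // mirrors failAnalysis in the real rule
        sys.error(s"expression '$attr' is neither present in the group by, " +
          "nor is it an aggregate function.")
      }
    }

  def main(args: Array[String]): Unit = {
    checkValidAggregate(Seq("kEy"), Seq("key"))      // passes: differs only in capitalization
    // checkValidAggregate(Seq("value"), Seq("key")) // would throw, like an AnalysisException
    println("kEy accepted against GROUP BY key")
  }
}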
@@ -76,6 +76,19 @@ abstract class Expression extends TreeNode[Expression] {
       case u: UnresolvedAttribute => PrettyAttribute(u.name)
     }.toString
   }
+
+  /**
+   * Returns true if two expressions are semantically equal, which is similar to the equals
+   * method but has a different definition for some leaf expressions like AttributeReference.
+   */
+  def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
+    val elements1 = this.productIterator.toSeq
+    val elements2 = other.asInstanceOf[Product].productIterator.toSeq
+    elements1.length == elements2.length && elements1.zip(elements2).forall {
+      case (e1: Expression, e2: Expression) => e1 semanticEquals e2
+      case (i1, i2) => i1 == i2
+    }
+  }
 }
 
 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {

Review comments on this hunk:

Contributor, on the scaladoc:
    Returns true when two expressions will always compute the same result, even if they differ
    cosmetically (i.e. capitalization of names in attributes may be different).

Contributor, on def semanticEquals:
    Scala doc please.

Contributor, on `case (i1, i2) => i1 == i2`:
    What about case class Coalesce(children: Seq[Expression])?

Contributor (reply):
    The difficulty here is that we probably can never know semanticEquals in a general way;
    that's why I said we would need to re-implement it for lots of expressions if we added this.
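The default implementation and a leaf override fit together roughly as in the following self-contained sketch (illustrative only; Attr, Lit, Add, Coalesce and the case-insensitive rule are stand-ins, not the real Catalyst classes). It also shows the point raised about Coalesce above: a Seq field falls into the plain `==` branch, so its elements are not compared with semanticEquals.

// Minimal sketch of the semanticEquals pattern.
sealed trait Expr extends Product {
  def semanticEquals(other: Expr): Boolean = this.getClass == other.getClass && {
    val a = this.productIterator.toSeq
    val b = other.productIterator.toSeq
    a.length == b.length && a.zip(b).forall {
      case (e1: Expr, e2: Expr) => e1 semanticEquals e2
      case (x, y)               => x == y // non-Expr fields fall back to plain equality
    }
  }
}

case class Attr(name: String) extends Expr {
  // Leaf override: two attributes are semantically equal if their names match ignoring case.
  override def semanticEquals(other: Expr): Boolean = other match {
    case Attr(n) => n.equalsIgnoreCase(name)
    case _       => false
  }
}
case class Lit(value: Int) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
// A node with a Seq field, like the Coalesce raised in the review: the Seq hits the
// `x == y` branch, so its elements are compared with ==, not with semanticEquals.
case class Coalesce(children: Seq[Expr]) extends Expr

object SemanticEqualsSketch extends App {
  val a = Add(Attr("kEy"), Lit(1))
  val b = Add(Attr("key"), Lit(1))
  println(a == b)             // false: plain equality is case sensitive
  println(a semanticEquals b) // true: the Attr override ignores case
  println(Coalesce(Seq(Attr("kEy"))) semanticEquals Coalesce(Seq(Attr("key")))) // false
}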
@@ -181,6 +181,11 @@ case class AttributeReference(
     case _ => false
   }
 
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case ar: AttributeReference => sameRef(ar)
+    case _ => false
+  }
+
   override def hashCode: Int = {
     // See http://stackoverflow.com/questions/113511/hash-code-implementation
     var h = 17
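The override here delegates to sameRef, i.e. it compares references rather than names. A rough standalone sketch of that idea follows (ExprId and AttrRef are simplified stand-ins for the Catalyst types, not the real classes):

// Sketch: two attribute references are "the same" when their ids match,
// independent of how the name was capitalized in the query text.
case class ExprId(id: Long)
case class AttrRef(name: String, exprId: ExprId) {
  def sameRef(other: AttrRef): Boolean = this.exprId == other.exprId
  def semanticEquals(other: AttrRef): Boolean = sameRef(other)
}

object AttrRefSketch extends App {
  val id = ExprId(1)
  println(AttrRef("key", id) semanticEquals AttrRef("kEy", id))        // true: same reference
  println(AttrRef("key", id) semanticEquals AttrRef("key", ExprId(2))) // false: different reference
}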
@@ -159,9 +159,10 @@ object PartialAggregation {
         // Should trim aliases around `GetField`s. These aliases are introduced while
         // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
         // (Should we just turn `GetField` into a `NamedExpression`?)
+        val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
         namedGroupingExpressions
-          .get(e.transform { case Alias(g: ExtractValue, _) => g })
-          .map(_.toAttribute)
+          .find { case (k, v) => k semanticEquals trimmed }
+          .map(_._2.toAttribute)
           .getOrElse(e)
       }).asInstanceOf[Seq[NamedExpression]]
 
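A small standalone sketch of the lookup change in this hunk (plain Scala; the string keys and the alias name are made up for illustration): instead of an exact map lookup on the rewritten expression, the (expression, alias) pairs are scanned for a key that is semantically equal, approximated here by case-insensitive comparison.

object GroupingLookupSketch extends App {
  // Stand-in for the named grouping expressions: grouping expression -> its alias.
  val namedGroupingExpressions: Seq[(String, String)] = Seq("key + 1" -> "groupingAlias1")

  def lookup(e: String): String =
    namedGroupingExpressions
      .find { case (k, _) => k.equalsIgnoreCase(e) } // was: exact .get(e)
      .map(_._2)
      .getOrElse(e)                                  // fall through unchanged if no match

  println(lookup("kEy + 1")) // groupingAlias1
  println(lookup("value"))   // value
}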
@@ -773,4 +773,22 @@ class SQLQuerySuite extends QueryTest {
         | select * from v2 order by key limit 1
       """.stripMargin), Row(0, 3))
   }
+
+  test("SPARK-7269 Check analysis failed in case in-sensitive") {
+    Seq(1, 2, 3).map { i =>
+      (i.toString, i.toString)
+    }.toDF("key", "value").registerTempTable("df_analysis")
+    sql("SELECT kEy from df_analysis group by key").collect()
+    sql("SELECT kEy+3 from df_analysis group by key+3").collect()
+    sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
+    sql("SELECT 2 from df_analysis A group by key+1").collect()
+    intercept[AnalysisException] {
+      sql("SELECT kEy+1 from df_analysis group by key+3")
+    }
+    intercept[AnalysisException] {
+      sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
+    }
+  }
 }