diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c5a1437be6d0..8a60af4ebcfd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -85,12 +85,12 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK - case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty => + case e: Attribute if groupingExprs.find(_ == e).isEmpty => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + s"nor is it an aggregate function. " + "Add to group by or wrap in first() if you don't care which value you get.") - case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK + case e if groupingExprs.find(_ == e).isDefined => // OK case e if e.references.isEmpty => // OK case e => e.children.foreach(checkValidAggregateExpression) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 5345696570b4..563316e898fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -18,27 +18,14 @@ package org.apache.spark.sql.catalyst.expressions -protected class AttributeEquals(val a: Attribute) { - override def hashCode(): Int = a match { - case ar: AttributeReference => ar.exprId.hashCode() - case a => a.hashCode() - } - - override def equals(other: Any): Boolean = (a, other.asInstanceOf[AttributeEquals].a) match { - case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId - case (a1, a2) => a1 == a2 - } -} - object AttributeSet { - def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) + def apply(a: Attribute): AttributeSet = new AttributeSet(Set(a)) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ def apply(baseSet: Iterable[Expression]): AttributeSet = { new AttributeSet( baseSet - .flatMap(_.references) - .map(new AttributeEquals(_)).toSet) + .flatMap(_.references).toSet) } } @@ -53,30 +40,30 @@ object AttributeSet { * and also makes doing transformations hard (we always try keep older trees instead of new ones * when the transformation was a no-op). */ -class AttributeSet private (val baseSet: Set[AttributeEquals]) +class AttributeSet private (val baseSet: Set[Attribute]) extends Traversable[Attribute] with Serializable { /** Returns true if the members of this AttributeSet and other are the same. */ override def equals(other: Any): Boolean = other match { case otherSet: AttributeSet => - otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains) + otherSet.size == baseSet.size && baseSet.forall(otherSet.contains) case _ => false } /** Returns true if this set contains an Attribute with the same expression id as `elem` */ def contains(elem: NamedExpression): Boolean = - baseSet.contains(new AttributeEquals(elem.toAttribute)) + baseSet.contains(elem.toAttribute) /** Returns a new [[AttributeSet]] that contains `elem` in addition to the current elements. */ def +(elem: Attribute): AttributeSet = // scalastyle:ignore - new AttributeSet(baseSet + new AttributeEquals(elem)) + new AttributeSet(baseSet + elem) /** Returns a new [[AttributeSet]] that does not contain `elem`. */ def -(elem: Attribute): AttributeSet = - new AttributeSet(baseSet - new AttributeEquals(elem)) + new AttributeSet(baseSet - elem) /** Returns an iterator containing all of the attributes in the set. */ - def iterator: Iterator[Attribute] = baseSet.map(_.a).iterator + def iterator: Iterator[Attribute] = baseSet.iterator /** * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in @@ -89,7 +76,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) * in `other`. */ def --(other: Traversable[NamedExpression]): AttributeSet = - new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) + new AttributeSet(baseSet -- other.map(_.toAttribute)) /** * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found @@ -102,7 +89,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) * true. */ override def filter(f: Attribute => Boolean): AttributeSet = - new AttributeSet(baseSet.filter(ae => f(ae.a))) + new AttributeSet(baseSet.filter(f)) /** * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in @@ -111,13 +98,13 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) def intersect(other: AttributeSet): AttributeSet = new AttributeSet(baseSet.intersect(other.baseSet)) - override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) + override def foreach[U](f: (Attribute) => U): Unit = baseSet.foreach(f) // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all // sorts of things in its closure. - override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq + override def toSeq: Seq[Attribute] = baseSet.toArray.toSeq - override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}" + override def toString: String = "{" + baseSet.mkString(", ") + "}" override def isEmpty: Boolean = baseSet.isEmpty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a10a959ae766..7d3bd6036364 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -141,24 +141,6 @@ abstract class Expression extends TreeNode[Expression] { }.toString } - /** - * Returns true when two expressions will always compute the same result, even if they differ - * cosmetically (i.e. capitalization of names in attributes may be different). - */ - def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { - def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { - elements1.length == elements2.length && elements1.zip(elements2).forall { - case (e1: Expression, e2: Expression) => e1 semanticEquals e2 - case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 - case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) - case (i1, i2) => i1 == i2 - } - } - val elements1 = this.productIterator.toSeq - val elements2 = other.asInstanceOf[Product].productIterator.toSeq - checkSemantic(elements1, elements2) - } - /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 9cacdceb1383..7e19e8b10b03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -181,12 +181,8 @@ case class AttributeReference( def sameRef(other: AttributeReference): Boolean = this.exprId == other.exprId override def equals(other: Any): Boolean = other match { - case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType - case _ => false - } - - override def semanticEquals(other: Expression): Boolean = other match { - case ar: AttributeReference => sameRef(ar) + case ar: AttributeReference => + exprId == ar.exprId && dataType == ar.dataType && metadata == ar.metadata case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 179a348d5baa..b7c1cc5217be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -151,13 +151,13 @@ object PartialAggregation { // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation case e: Expression => namedGroupingExpressions.collectFirst { - case (expr, ne) if expr semanticEquals e => ne.toAttribute + case (expr, ne) if expr == e => ne.toAttribute }.getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 09f6c6b0ec42..ca49617da63d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -154,7 +154,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val newArgs = productIterator.map { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) - if (newChild fastEquals arg) { + if (newChild eq arg) { arg } else { changed = true @@ -181,7 +181,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { + if (newChild eq oldChild) { oldChild } else { changed = true @@ -193,7 +193,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { + if (newChild eq oldChild) { oldChild } else { changed = true @@ -228,7 +228,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } // Check if unchanged and then possibly return old copy to avoid gc churn. - if (this fastEquals afterRule) { + if (this eq afterRule) { transformChildrenDown(rule) } else { afterRule.transformChildrenDown(rule) @@ -245,7 +245,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val newArgs = productIterator.map { case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) - if (!(newChild fastEquals arg)) { + if (newChild ne arg) { changed = true newChild } else { @@ -253,7 +253,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } case Some(arg: TreeNode[_]) if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) - if (!(newChild fastEquals arg)) { + if (newChild ne arg) { changed = true Some(newChild) } else { @@ -264,7 +264,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) - if (!(newChild fastEquals arg)) { + if (newChild ne arg) { changed = true newChild } else { @@ -286,7 +286,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { val afterRuleOnChildren = transformChildrenUp(rule) - if (this fastEquals afterRuleOnChildren) { + if (this eq afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) } @@ -302,7 +302,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val newArgs = productIterator.map { case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { + if (newChild ne arg) { changed = true newChild } else { @@ -310,7 +310,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } case Some(arg: TreeNode[_]) if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { + if (newChild ne arg) { changed = true Some(newChild) } else { @@ -321,7 +321,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { + if (newChild ne arg) { changed = true newChild } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala index 97cfb5f06dd7..46c5ba222fa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -34,8 +34,8 @@ class AttributeSetSuite extends SparkFunSuite { val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) test("sanity check") { - assert(aUpper != aLower) - assert(bUpper != bLower) + assert(aUpper == aLower) + assert(bUpper == bLower) } test("checks by id not name") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index ba2c8f53d702..6f96ecf5eed5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -224,7 +224,7 @@ case class GeneratedAggregate( case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) case e: Expression => namedGroups.collectFirst { - case (expr, attr) if expr semanticEquals e => attr + case (expr, attr) if expr == e => attr }.getOrElse(e) })