Skip to content

Commit fbeab54

Browse files
committed
Better toString, factories for AttributeSet.
1 parent cf1d32e commit fbeab54

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,26 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import org.apache.spark.sql.catalyst.analysis.Star
21+
2022
protected class AttributeEquals(val a: Attribute) {
2123
override def hashCode() = a.exprId.hashCode()
22-
override def equals(other: Any) = other match {
23-
case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
24-
case otherAttribute => false
24+
override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
25+
case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
26+
case (a1, a2) => a1 == a2
2527
}
2628
}
2729

2830
object AttributeSet {
29-
/** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
30-
def apply(baseSet: Seq[Attribute]) = {
31-
new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
32-
}
31+
def apply(a: Attribute) =
32+
new AttributeSet(Set(new AttributeEquals(a)))
33+
34+
/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
35+
def apply(baseSet: Seq[Expression]) =
36+
new AttributeSet(
37+
baseSet
38+
.flatMap(_.references)
39+
.map(new AttributeEquals(_)).toSet)
3340
}
3441

3542
/**
@@ -103,4 +110,6 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
103110
// We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
104111
// sorts of things in its closure.
105112
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
113+
114+
override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}"
106115
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ abstract class NamedExpression extends Expression {
5757
abstract class Attribute extends NamedExpression {
5858
self: Product =>
5959

60+
override def references = AttributeSet(this)
61+
6062
def withNullability(newNullability: Boolean): Attribute
6163
def withQualifiers(newQualifiers: Seq[String]): Attribute
6264
def withName(newName: String): Attribute
@@ -116,8 +118,6 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
116118
(val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
117119
extends Attribute with trees.LeafNode[Expression] {
118120

119-
override def references = AttributeSet(this :: Nil)
120-
121121
override def equals(other: Any) = other match {
122122
case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
123123
case _ => false

0 commit comments

Comments
 (0)