@@ -17,19 +17,26 @@

 package org.apache.spark.sql.catalyst.expressions

+import org.apache.spark.sql.catalyst.analysis.Star
+
 protected class AttributeEquals(val a: Attribute) {
   override def hashCode() = a.exprId.hashCode()
-  override def equals(other: Any) = other match {
-    case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
-    case otherAttribute => false
+  override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
+    case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
+    case (a1, a2) => a1 == a2
   }
 }

 object AttributeSet {
-  /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
-  def apply(baseSet: Seq[Attribute]) = {
-    new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
-  }
+  def apply(a: Attribute) =
+    new AttributeSet(Set(new AttributeEquals(a)))
+
+  /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
+  def apply(baseSet: Seq[Expression]) =
+    new AttributeSet(
+      baseSet
+        .flatMap(_.references)
+        .map(new AttributeEquals(_)).toSet)
 }

 /**
@@ -103,4 +110,6 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
   // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
   // sorts of things in its closure.
   override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
+
+  override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}"
 }
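
For context, a rough sketch of how the two new AttributeSet constructors might be exercised. This is illustrative only, not part of the change; it assumes a Catalyst classpath where AttributeReference, Add, and IntegerType are available, and the IntegerType import path shown matches this era of the codebase (newer versions move it to org.apache.spark.sql.types).

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType

// Two distinct attributes, each with its own exprId.
val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// Single-attribute constructor added above.
val justA = AttributeSet(a)

// Seq[Expression] constructor: attribute references are collected from
// arbitrary expressions, so Add(a, b) contributes both a and b to the set.
val fromExprs = AttributeSet(Seq(Add(a, b), a))

// With the new toString, this prints something like {a#0, b#1}
// (the exprId suffixes will differ from run to run).
println(fromExprs)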