Skip to content

Commit 276ef1c

Browse files
committed
[SPARK-6463][SQL] AttributeSet.equal should compare size
Previously this could result in sets compare equals when in fact the right was a subset of the left. Based on apache#5133 by sisihj Author: sisihj <[email protected]> Author: Michael Armbrust <[email protected]> Closes apache#5194 from marmbrus/pr/5133 and squashes the following commits: 5ed4615 [Michael Armbrust] fix imports d4cbbc0 [Michael Armbrust] Add test cases 0a0834f [sisihj] AttributeSet.equal should compare size
1 parent e87bf37 commit 276ef1c

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
5858

5959
/** Returns true if the members of this AttributeSet and other are the same. */
6060
override def equals(other: Any): Boolean = other match {
61-
case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains)
61+
case otherSet: AttributeSet =>
62+
otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains)
6263
case _ => false
6364
}
6465

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.sql.types.IntegerType
23+
24+
class AttributeSetSuite extends FunSuite {
25+
26+
val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
27+
val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
28+
val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
29+
val aSet = AttributeSet(aLower :: Nil)
30+
31+
val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
32+
val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
33+
val bSet = AttributeSet(bUpper :: Nil)
34+
35+
val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)
36+
37+
test("sanity check") {
38+
assert(aUpper != aLower)
39+
assert(bUpper != bLower)
40+
}
41+
42+
test("checks by id not name") {
43+
assert(aSet.contains(aUpper) === true)
44+
assert(aSet.contains(aLower) === true)
45+
assert(aSet.contains(fakeA) === false)
46+
47+
assert(aSet.contains(bUpper) === false)
48+
assert(aSet.contains(bLower) === false)
49+
}
50+
51+
test("++ preserves AttributeSet") {
52+
assert((aSet ++ bSet).contains(aUpper) === true)
53+
assert((aSet ++ bSet).contains(aLower) === true)
54+
}
55+
56+
test("extracts all references references") {
57+
val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)
58+
assert(addSet.contains(aUpper))
59+
assert(addSet.contains(aLower))
60+
assert(addSet.contains(bUpper))
61+
assert(addSet.contains(bLower))
62+
}
63+
64+
test("dedups attributes") {
65+
assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)
66+
}
67+
68+
test("subset") {
69+
assert(aSet.subsetOf(aAndBSet) === true)
70+
assert(aAndBSet.subsetOf(aSet) === false)
71+
}
72+
73+
test("equality") {
74+
assert(aSet != aAndBSet)
75+
assert(aAndBSet != aSet)
76+
assert(aSet != bSet)
77+
assert(bSet != aSet)
78+
79+
assert(aSet == aSet)
80+
assert(aSet == AttributeSet(aUpper :: Nil))
81+
}
82+
}

0 commit comments

Comments
 (0)