Skip to content

Commit 4b18418

Browse files
committed
Fix a bug and add unit tests
1 parent 514d3d2 commit 4b18418

File tree

2 files changed

+62
-10
lines changed

2 files changed

+62
-10
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ case class ExpandNode(
3232
private[this] var result: InternalRow = _
3333
private[this] var idx: Int = _
3434
private[this] var input: InternalRow = _
35-
3635
private[this] var groups: Array[Projection] = _
3736

3837
override def open(): Unit = {
@@ -42,17 +41,19 @@ case class ExpandNode(
4241
}
4342

4443
override def next(): Boolean = {
45-
idx += 1
46-
if (idx < groups.length) {
47-
result = groups(idx)(input)
48-
true
49-
} else if (child.next()) {
50-
input = child.fetch()
51-
idx = 0
44+
if (idx < 0 || idx >= groups.length) {
45+
if (child.next()) {
46+
input = child.fetch()
47+
result = groups(0)(input)
48+
idx = 1
49+
true
50+
} else {
51+
false
52+
}
53+
} else {
5254
result = groups(idx)(input)
55+
idx += 1
5356
true
54-
} else {
55-
false
5657
}
5758
}
5859

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.local
19+
20+
class ExpandNodeSuite extends LocalNodeTest {
21+
22+
import testImplicits._
23+
24+
test("expand") {
25+
val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
26+
checkAnswer(
27+
input,
28+
node =>
29+
ExpandNode(conf, Seq(
30+
Seq(
31+
input.col("key") + input.col("value"), input.col("key") - input.col("value")
32+
).map(_.expr),
33+
Seq(
34+
input.col("key") * input.col("value"), input.col("key") / input.col("value")
35+
).map(_.expr)
36+
), node.output, node),
37+
Seq(
38+
(2, 0),
39+
(1, 1),
40+
(4, 0),
41+
(4, 1),
42+
(6, 0),
43+
(9, 1),
44+
(8, 0),
45+
(16, 1),
46+
(10, 0),
47+
(25, 1)
48+
).toDF().collect()
49+
)
50+
}
51+
}

0 commit comments

Comments
 (0)