Skip to content

Commit 217e496

Browse files
zsxwingAndrew Or
authored andcommitted
[SPARK-9996] [SPARK-9997] [SQL] Add local expand and NestedLoopJoin operators
This PR is in conflict with #8535 and #8573. Will update this one when they are merged. Author: zsxwing <[email protected]> Closes #8642 from zsxwing/expand-nest-join.
1 parent 64f0415 commit 217e496

File tree

7 files changed

+574
-15
lines changed

7 files changed

+574
-15
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.local
19+
20+
import org.apache.spark.sql.SQLConf
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection}
23+
24+
case class ExpandNode(
25+
conf: SQLConf,
26+
projections: Seq[Seq[Expression]],
27+
output: Seq[Attribute],
28+
child: LocalNode) extends UnaryLocalNode(conf) {
29+
30+
assert(projections.size > 0)
31+
32+
private[this] var result: InternalRow = _
33+
private[this] var idx: Int = _
34+
private[this] var input: InternalRow = _
35+
private[this] var groups: Array[Projection] = _
36+
37+
override def open(): Unit = {
38+
child.open()
39+
groups = projections.map(ee => newProjection(ee, child.output)).toArray
40+
idx = groups.length
41+
}
42+
43+
override def next(): Boolean = {
44+
if (idx >= groups.length) {
45+
if (child.next()) {
46+
input = child.fetch()
47+
idx = 0
48+
} else {
49+
return false
50+
}
51+
}
52+
result = groups(idx)(input)
53+
idx += 1
54+
true
55+
}
56+
57+
override def fetch(): InternalRow = result
58+
59+
override def close(): Unit = child.close()
60+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
2323
import org.apache.spark.sql.{SQLConf, Row}
2424
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
2525
import org.apache.spark.sql.catalyst.expressions._
26-
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
26+
import org.apache.spark.sql.catalyst.expressions.codegen._
2727
import org.apache.spark.sql.catalyst.trees.TreeNode
2828
import org.apache.spark.sql.types.StructType
2929

@@ -69,6 +69,18 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
6969
*/
7070
def close(): Unit
7171

72+
/** Specifies whether this operator outputs UnsafeRows */
73+
def outputsUnsafeRows: Boolean = false
74+
75+
/** Specifies whether this operator is capable of processing UnsafeRows */
76+
def canProcessUnsafeRows: Boolean = false
77+
78+
/**
79+
* Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows
80+
* that are not UnsafeRows).
81+
*/
82+
def canProcessSafeRows: Boolean = true
83+
7284
/**
7385
* Returns the content through the [[Iterator]] interface.
7486
*/
@@ -91,6 +103,28 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
91103
result
92104
}
93105

106+
protected def newProjection(
107+
expressions: Seq[Expression],
108+
inputSchema: Seq[Attribute]): Projection = {
109+
log.debug(
110+
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
111+
if (codegenEnabled) {
112+
try {
113+
GenerateProjection.generate(expressions, inputSchema)
114+
} catch {
115+
case NonFatal(e) =>
116+
if (isTesting) {
117+
throw e
118+
} else {
119+
log.error("Failed to generate projection, fallback to interpret", e)
120+
new InterpretedProjection(expressions, inputSchema)
121+
}
122+
}
123+
} else {
124+
new InterpretedProjection(expressions, inputSchema)
125+
}
126+
}
127+
94128
protected def newMutableProjection(
95129
expressions: Seq[Expression],
96130
inputSchema: Seq[Attribute]): () => MutableProjection = {
@@ -113,6 +147,25 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
113147
}
114148
}
115149

150+
protected def newPredicate(
151+
expression: Expression,
152+
inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
153+
if (codegenEnabled) {
154+
try {
155+
GeneratePredicate.generate(expression, inputSchema)
156+
} catch {
157+
case NonFatal(e) =>
158+
if (isTesting) {
159+
throw e
160+
} else {
161+
log.error("Failed to generate predicate, fallback to interpreted", e)
162+
InterpretedPredicate.create(expression, inputSchema)
163+
}
164+
}
165+
} else {
166+
InterpretedPredicate.create(expression, inputSchema)
167+
}
168+
}
116169
}
117170

118171

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.local
19+
20+
import org.apache.spark.sql.SQLConf
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType}
24+
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
25+
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
26+
27+
case class NestedLoopJoinNode(
28+
conf: SQLConf,
29+
left: LocalNode,
30+
right: LocalNode,
31+
buildSide: BuildSide,
32+
joinType: JoinType,
33+
condition: Option[Expression]) extends BinaryLocalNode(conf) {
34+
35+
override def output: Seq[Attribute] = {
36+
joinType match {
37+
case LeftOuter =>
38+
left.output ++ right.output.map(_.withNullability(true))
39+
case RightOuter =>
40+
left.output.map(_.withNullability(true)) ++ right.output
41+
case FullOuter =>
42+
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
43+
case x =>
44+
throw new IllegalArgumentException(
45+
s"NestedLoopJoin should not take $x as the JoinType")
46+
}
47+
}
48+
49+
private[this] def genResultProjection: InternalRow => InternalRow = {
50+
if (outputsUnsafeRows) {
51+
UnsafeProjection.create(schema)
52+
} else {
53+
identity[InternalRow]
54+
}
55+
}
56+
57+
private[this] var currentRow: InternalRow = _
58+
59+
private[this] var iterator: Iterator[InternalRow] = _
60+
61+
override def open(): Unit = {
62+
val (streamed, build) = buildSide match {
63+
case BuildRight => (left, right)
64+
case BuildLeft => (right, left)
65+
}
66+
build.open()
67+
val buildRelation = new CompactBuffer[InternalRow]
68+
while (build.next()) {
69+
buildRelation += build.fetch().copy()
70+
}
71+
build.close()
72+
73+
val boundCondition =
74+
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
75+
76+
val leftNulls = new GenericMutableRow(left.output.size)
77+
val rightNulls = new GenericMutableRow(right.output.size)
78+
val joinedRow = new JoinedRow
79+
val matchedBuildTuples = new BitSet(buildRelation.size)
80+
val resultProj = genResultProjection
81+
streamed.open()
82+
83+
// streamedRowMatches also contains null rows if using outer join
84+
val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow =>
85+
val matchedRows = new CompactBuffer[InternalRow]
86+
87+
var i = 0
88+
var streamRowMatched = false
89+
90+
// Scan the build relation to look for matches for each streamed row
91+
while (i < buildRelation.size) {
92+
val buildRow = buildRelation(i)
93+
buildSide match {
94+
case BuildRight => joinedRow(streamedRow, buildRow)
95+
case BuildLeft => joinedRow(buildRow, streamedRow)
96+
}
97+
if (boundCondition(joinedRow)) {
98+
matchedRows += resultProj(joinedRow).copy()
99+
streamRowMatched = true
100+
matchedBuildTuples.set(i)
101+
}
102+
i += 1
103+
}
104+
105+
// If this row had no matches and we're using outer join, join it with the null rows
106+
if (!streamRowMatched) {
107+
(joinType, buildSide) match {
108+
case (LeftOuter | FullOuter, BuildRight) =>
109+
matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
110+
case (RightOuter | FullOuter, BuildLeft) =>
111+
matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
112+
case _ =>
113+
}
114+
}
115+
116+
matchedRows.iterator
117+
}
118+
119+
// If we're using outer join, find rows on the build side that didn't match anything
120+
// and join them with the null row
121+
lazy val unmatchedBuildRows: Iterator[InternalRow] = {
122+
var i = 0
123+
buildRelation.filter { row =>
124+
val r = !matchedBuildTuples.get(i)
125+
i += 1
126+
r
127+
}.iterator
128+
}
129+
iterator = (joinType, buildSide) match {
130+
case (RightOuter | FullOuter, BuildRight) =>
131+
streamedRowMatches ++
132+
unmatchedBuildRows.map { buildRow => resultProj(joinedRow(leftNulls, buildRow)) }
133+
case (LeftOuter | FullOuter, BuildLeft) =>
134+
streamedRowMatches ++
135+
unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) }
136+
case _ => streamedRowMatches
137+
}
138+
}
139+
140+
override def next(): Boolean = {
141+
if (iterator.hasNext) {
142+
currentRow = iterator.next()
143+
true
144+
} else {
145+
false
146+
}
147+
}
148+
149+
override def fetch(): InternalRow = currentRow
150+
151+
override def close(): Unit = {
152+
left.close()
153+
right.close()
154+
}
155+
156+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.local
19+
20+
class ExpandNodeSuite extends LocalNodeTest {
21+
22+
import testImplicits._
23+
24+
test("expand") {
25+
val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
26+
checkAnswer(
27+
input,
28+
node =>
29+
ExpandNode(conf, Seq(
30+
Seq(
31+
input.col("key") + input.col("value"), input.col("key") - input.col("value")
32+
).map(_.expr),
33+
Seq(
34+
input.col("key") * input.col("value"), input.col("key") / input.col("value")
35+
).map(_.expr)
36+
), node.output, node),
37+
Seq(
38+
(2, 0),
39+
(1, 1),
40+
(4, 0),
41+
(4, 1),
42+
(6, 0),
43+
(9, 1),
44+
(8, 0),
45+
(16, 1),
46+
(10, 0),
47+
(25, 1)
48+
).toDF().collect()
49+
)
50+
}
51+
}

sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,6 @@ class HashJoinNodeSuite extends LocalNodeTest {
2424

2525
import testImplicits._
2626

27-
private def wrapForUnsafe(
28-
f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
29-
if (conf.unsafeEnabled) {
30-
(left: LocalNode, right: LocalNode) => {
31-
val _left = ConvertToUnsafeNode(conf, left)
32-
val _right = ConvertToUnsafeNode(conf, right)
33-
val r = f(_left, _right)
34-
ConvertToSafeNode(conf, r)
35-
}
36-
} else {
37-
f
38-
}
39-
}
40-
4127
def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
4228
test(s"$suiteName: inner join with one match per row") {
4329
withSQLConf(confPairs: _*) {

sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
2727

2828
def conf: SQLConf = sqlContext.conf
2929

30+
protected def wrapForUnsafe(
31+
f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
32+
if (conf.unsafeEnabled) {
33+
(left: LocalNode, right: LocalNode) => {
34+
val _left = ConvertToUnsafeNode(conf, left)
35+
val _right = ConvertToUnsafeNode(conf, right)
36+
val r = f(_left, _right)
37+
ConvertToSafeNode(conf, r)
38+
}
39+
} else {
40+
f
41+
}
42+
}
43+
3044
/**
3145
* Runs the LocalNode and makes sure the answer matches the expected result.
3246
* @param input the input data to be used.

0 commit comments

Comments
 (0)