
Commit 8e1a2f4

Address backward compatibility with old state

1 parent 47d7a29 commit 8e1a2f4

3 files changed: 56 additions, 20 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala

Lines changed: 41 additions & 4 deletions
@@ -266,15 +266,43 @@ case class StreamingSymmetricHashJoinExec(
       case Inner =>
         innerOutputIter
       case LeftOuter =>
+        // We generate the outer join input by:
+        // * Getting an iterator over the rows that have aged out on the left side. These rows are
+        //   candidates for being null joined. Note that to avoid doing two passes, this iterator
+        //   removes the rows from the state manager as they're processed.
+        // * Checking whether the current row matches a key in the right side state, and that key
+        //   has any value which satisfies the filter function when joined. If it doesn't,
+        //   we know we can join with null, since there was never (including this batch) a match
+        //   within the watermark period. If it does, there must have been a match at some point, so
+        //   we know we can't join with null.
+        def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = {
+          rightSideJoiner.get(leftKeyValue.key).exists { rightValue =>
+            postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue))
+          }
+        }
+
         val removedRowIter = leftSideJoiner.removeOldState()
-        val outerOutputIter = removedRowIter.filterNot(_.matched)
-          .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+        val outerOutputIter = removedRowIter.filterNot { kvAndMatched =>
+          kvAndMatched.matched.getOrElse(
+            // fail-back for previous state on SPARK-26154
+            matchesWithRightSideState(new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value)))
+        }.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))

         innerOutputIter ++ outerOutputIter
       case RightOuter =>
+        // See comments for left outer case.
+        def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = {
+          leftSideJoiner.get(rightKeyValue.key).exists { leftValue =>
+            postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
+          }
+        }
+
         val removedRowIter = rightSideJoiner.removeOldState()
-        val outerOutputIter = removedRowIter.filterNot(_.matched)
-          .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+        val outerOutputIter = removedRowIter.filterNot { kvAndMatched =>
+          kvAndMatched.matched.getOrElse(
+            // fail-back for previous state on SPARK-26154
+            matchesWithLeftSideState(new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value)))
+        }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))

         innerOutputIter ++ outerOutputIter
       case _ => throwBadJoinTypeException()
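
The backward-compatibility core of this hunk is that `matched` is now an `Option[Boolean]`: rows written by the pre-SPARK-26154 state format carry no matched flag, so they surface as `None` and the join falls back to re-probing the other side's state rather than assuming they never matched. A minimal sketch of that decision in isolation, with plain types standing in for `UnsafeRow` and the state manager (the names `AgedOutRow`, `emitAsOuterRow`, and `probeOtherSide` are illustrative, not Spark APIs):

```scala
object OuterJoinFallbackSketch {
  // Simplified stand-in for a row aged out of one side's state; the "matched"
  // flag may be absent because the row was written by an older state format.
  final case class AgedOutRow[K, V](key: K, value: V, matched: Option[Boolean])

  // Decide whether the row should be emitted as a null-joined outer row.
  // `probeOtherSide` plays the role of matchesWithRightSideState: it re-checks
  // the opposite side's current state when the stored flag is missing.
  def emitAsOuterRow[K, V](row: AgedOutRow[K, V])(probeOtherSide: (K, V) => Boolean): Boolean =
    // Some(true)  -> matched at some point, never emitted as an outer row
    // Some(false) -> recorded as never matched, emit with nulls
    // None        -> old-format state: recompute the match against current state
    !row.matched.getOrElse(probeOtherSide(row.key, row.value))

  def main(args: Array[String]): Unit = {
    // An old-format row with no current match on the other side is null-joined.
    assert(emitAsOuterRow(AgedOutRow(key = 1, value = "a", matched = None))((_, _) => false))
    // A row the new format recorded as matched is never null-joined.
    assert(!emitAsOuterRow(AgedOutRow(key = 1, value = "b", matched = Some(true)))((_, _) => false))
  }
}
```
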
@@ -465,6 +493,15 @@ case class StreamingSymmetricHashJoinExec(
       }
     }

+    /**
+     * Get an iterator over the values stored in this joiner's state manager for the given key.
+     *
+     * Should not be interleaved with mutations.
+     */
+    def get(key: UnsafeRow): Iterator[UnsafeRow] = {
+      joinStateManager.get(key)
+    }
+
     /**
      * Builds an iterator over old state key-value pairs, removing them lazily as they're produced.
      *
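
The new accessor hands back an `Iterator[UnsafeRow]`, so a caller such as `matchesWithRightSideState` above can stop at the first value that passes the post-join filter via `exists` instead of materializing every value stored for the key. A small sketch of that access pattern, with a plain map standing in for the joiner's state manager (names are illustrative, not Spark APIs):

```scala
object GetAndProbeSketch {
  // A plain multimap standing in for one side's join state: key -> stored values.
  private val state: Map[Int, Seq[String]] = Map(1 -> Seq("a", "b", "c"))

  // Analogue of OneSideHashJoiner.get: an iterator over the values for a key.
  def get(key: Int): Iterator[String] = state.getOrElse(key, Seq.empty).iterator

  def main(args: Array[String]): Unit = {
    // Analogue of matchesWithRightSideState: exists short-circuits on the first
    // value that satisfies the (here trivial) post-join filter.
    assert(get(1).exists(_ == "b"))
    assert(!get(2).exists(_ == "b"))
  }
}
```
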

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala

Lines changed: 14 additions & 15 deletions
@@ -105,13 +105,6 @@ class SymmetricHashJoinStateManager(
     }.filter(_ != null)
   }

-  /** Append a new value to the key */
-  def append(key: UnsafeRow, value: UnsafeRow): Unit = {
-    val numExistingValues = keyToNumValues.get(key)
-    keyWithIndexToValue.put(key, numExistingValues, value)
-    keyToNumValues.put(key, numExistingValues + 1)
-  }
-
   /** Append a new value to the key, with marking matched */
   def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit = {
     val numExistingValues = keyToNumValues.get(key)
@@ -223,7 +216,7 @@

       // Find the next value satisfying the condition, updating `currentKey` and `numValues` if
       // needed. Returns null when no value can be found.
-      private def findNextValueForIndex(): (UnsafeRow, Boolean) = {
+      private def findNextValueForIndex(): (UnsafeRow, Option[Boolean]) = {
        // Loop across all values for the current key, and then all other keys, until we find a
        // value satisfying the removal condition.
        def hasMoreValuesForCurrentKey = currentKey != null && index < numValues
@@ -273,9 +266,13 @@
           keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex)
           keyWithIndexToValue.remove(currentKey, numValues - 1)

-          val matchedAtMaxIndex = keyWithIndexToMatched.get(currentKey, numValues - 1)
-          keyWithIndexToMatched.put(currentKey, index, matchedAtMaxIndex)
-          keyWithIndexToMatched.remove(currentKey, numValues - 1)
+          keyWithIndexToMatched.get(currentKey, numValues - 1) match {
+            case Some(matchedAtMaxIndex) =>
+              keyWithIndexToMatched.put(currentKey, index, matchedAtMaxIndex)
+              keyWithIndexToMatched.remove(currentKey, numValues - 1)
+
+            case None =>
+          }
         } else {
           keyWithIndexToValue.remove(currentKey, 0)
           keyWithIndexToMatched.remove(currentKey, 0)
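
Removal keeps a key's values densely indexed by moving the value at the highest index into the freed slot; the change above carries the matched flag along only when one was actually stored, because rows from the old state format have no entry in the matched column at all. A simplified single-key sketch of that conditional relocation, with plain mutable maps standing in for the two stores (names are illustrative, not Spark APIs):

```scala
import scala.collection.mutable

object MatchedRelocationSketch {
  // Per-key "columns" indexed by value position, standing in for
  // keyWithIndexToValue and keyWithIndexToMatched.
  val values  = mutable.Map(0L -> "v0", 1L -> "v1", 2L -> "v2")
  // Only index 1 was written by the new state format; indices 0 and 2 are
  // old-format rows with no matched entry at all.
  val matched = mutable.Map(1L -> true)

  // Remove the value at `index` by relocating the value at the max index,
  // carrying the matched flag along only when one exists.
  def removeAt(index: Long, numValues: Long): Unit = {
    val maxIndex = numValues - 1
    values.put(index, values(maxIndex))
    values.remove(maxIndex)
    matched.get(maxIndex) match {
      case Some(matchedAtMaxIndex) =>
        matched.put(index, matchedAtMaxIndex)
        matched.remove(maxIndex)
      case None => // old-format row: there is no flag to relocate
    }
  }

  def main(args: Array[String]): Unit = {
    removeAt(index = 0L, numValues = 3L)
    assert(values.toMap == Map(0L -> "v2", 1L -> "v1"))
    assert(matched.get(0L).isEmpty && matched(1L))
  }
}
```
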
@@ -556,9 +553,9 @@

     protected val stateStore = getStateStore(keyWithIndexSchema, booleanValueSchema)

-    def get(key: UnsafeRow, valueIndex: Long): Boolean = {
+    def get(key: UnsafeRow, valueIndex: Long): Option[Boolean] = {
       val row = stateStore.get(keyWithIndexRow(key, valueIndex))
-      if (row != null) row.getBoolean(0) else false
+      if (row != null) Some(row.getBoolean(0)) else None
     }

     /** Put matched for key at the given index */
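
With this change the matched column's accessor distinguishes three cases instead of collapsing missing entries into `false`: `Some(true)` for a row that has matched, `Some(false)` for a row stored by the new format that has not matched yet, and `None` when no matched entry exists at all, which is what rows written by the previous state format look like. A minimal sketch of the old versus new decoding, with a plain map standing in for the nullable state-store lookup (names are illustrative, not Spark APIs):

```scala
object MatchedColumnSketch {
  // Stand-in for the keyWithIndexToMatched store: a nullable lookup by
  // (key, valueIndex). Old-format state simply has no entry.
  private val store = Map((1, 0L) -> true, (1, 1L) -> false)

  // Old behavior: a missing entry was indistinguishable from "not matched".
  def getV1(key: Int, valueIndex: Long): Boolean =
    store.getOrElse((key, valueIndex), false)

  // New behavior: surface the absence so callers can fall back to re-probing
  // the other side's state instead of silently treating the row as unmatched.
  def getV2(key: Int, valueIndex: Long): Option[Boolean] =
    store.get((key, valueIndex))

  def main(args: Array[String]): Unit = {
    assert(getV2(1, 0L).contains(true))   // matched under the new format
    assert(getV2(1, 1L).contains(false))  // stored, but never matched
    assert(getV2(1, 2L).isEmpty)          // old-format row: no matched column
    // The old accessor conflated the last two cases:
    assert(!getV1(1, 1L) && !getV1(1, 2L))
  }
}
```
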
@@ -601,8 +598,10 @@ object SymmetricHashJoinStateManager {
    * Designed for object reuse.
    */
   case class KeyToValueAndMatched(
-      var key: UnsafeRow = null, var value: UnsafeRow = null, var matched: Boolean = false) {
-    def withNew(newKey: UnsafeRow, newValue: UnsafeRow, newMatched: Boolean): this.type = {
+      var key: UnsafeRow = null,
+      var value: UnsafeRow = null,
+      var matched: Option[Boolean] = None) {
+    def withNew(newKey: UnsafeRow, newValue: UnsafeRow, newMatched: Option[Boolean]): this.type = {
       this.key = newKey
       this.value = newValue
       this.matched = newMatched
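
`KeyToValueAndMatched` is documented as designed for object reuse: the state iterator refreshes a single instance through `withNew` for each row it yields rather than allocating per row, so consumers need to read (or copy) the fields before advancing. A small sketch of that reuse pattern with simplified field types (illustrative, not the Spark class):

```scala
object ObjectReuseSketch {
  // Mutable, reusable carrier mirroring the shape of KeyToValueAndMatched,
  // with Int/String standing in for UnsafeRow.
  final case class Reused(
      var key: Int = 0,
      var value: String = null,
      var matched: Option[Boolean] = None) {
    def withNew(newKey: Int, newValue: String, newMatched: Option[Boolean]): this.type = {
      this.key = newKey
      this.value = newValue
      this.matched = newMatched
      this
    }
  }

  def main(args: Array[String]): Unit = {
    val reused = Reused()
    val rows: Seq[(Int, String, Option[Boolean])] = Seq((1, "a", Some(true)), (2, "b", None))
    // The iterator yields the same instance every time, refreshed in place.
    val iter = rows.iterator.map { case (k, v, m) => reused.withNew(k, v, m) }
    // Read the fields you need before advancing; collecting the Reused
    // instances themselves would just give several references to one object.
    val matchedFlags = iter.map(_.matched).toList
    assert(matchedFlags == List(Some(true), None))
  }
}
```
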

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
   def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0)

   def append(key: Int, value: Int)(implicit manager: SymmetricHashJoinStateManager): Unit = {
-    manager.append(toJoinKeyRow(key), toInputValue(value))
+    manager.append(toJoinKeyRow(key), toInputValue(value), matched = false)
   }

   def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = {
