diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 3d071df493cec..a52f5f4ac94ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -56,8 +56,8 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
  * - Apply the optional condition to filter the joined rows as the final output.
  *
  * If a timestamp column with event time watermark is present in the join keys or in the input
- * data, then the it uses the watermark figure out which rows in the buffer will not join with
- * and the new data, and therefore can be discarded. Depending on the provided query conditions, we
+ * data, then it uses the watermark to figure out which rows in the buffer will not join with
+ * the new data, and therefore can be discarded. Depending on the provided query conditions, we
  * can define thresholds on both state key (i.e. joining keys) and state value (i.e. input rows).
  * There are three kinds of queries possible regarding this as explained below.
  * Assume that watermark has been defined on both `leftTime` and `rightTime` columns used below.
@@ -134,7 +134,7 @@ case class StreamingSymmetricHashJoinExec(
     stateWatermarkPredicates: JoinStateWatermarkPredicates,
     stateFormatVersion: Int,
     left: SparkPlan,
-    right: SparkPlan) extends SparkPlan with BinaryExecNode with StateStoreWriter {
+    right: SparkPlan) extends BinaryExecNode with StateStoreWriter {
 
   def this(
       leftKeys: Seq[Expression],
@@ -157,14 +157,16 @@ case class StreamingSymmetricHashJoinExec(
       " the checkpoint and rerun the query. See SPARK-26154 for more details.")
   }
 
+  private lazy val errorMessageForJoinType =
+    s"${getClass.getSimpleName} should not take $joinType as the JoinType"
+
   private def throwBadJoinTypeException(): Nothing = {
-    throw new IllegalArgumentException(
-      s"${getClass.getSimpleName} should not take $joinType as the JoinType")
+    throw new IllegalArgumentException(errorMessageForJoinType)
   }
 
   require(
     joinType == Inner || joinType == LeftOuter || joinType == RightOuter,
-    s"${getClass.getSimpleName} should not take $joinType as the JoinType")
+    errorMessageForJoinType)
   require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType))
 
   private val storeConf = new StateStoreConf(sqlContext.conf)
@@ -189,11 +191,9 @@ case class StreamingSymmetricHashJoinExec(
   override def outputPartitioning: Partitioning = joinType match {
     case _: InnerLike =>
       PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
-    case LeftOuter => PartitioningCollection(Seq(left.outputPartitioning))
-    case RightOuter => PartitioningCollection(Seq(right.outputPartitioning))
-    case x =>
-      throw new IllegalArgumentException(
-        s"${getClass.getSimpleName} should not take $x as the JoinType")
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case _ => throwBadJoinTypeException()
   }
 
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
@@ -246,13 +246,14 @@ case class StreamingSymmetricHashJoinExec(
     // Join one side input using the other side's buffered/state rows. Here is how it is done.
     //
-    //  - `leftJoiner.joinWith(rightJoiner)` generates all rows from matching new left input with
-    //    stored right input, and also stores all the left input
+    //  - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)` generates all rows from
+    //    matching new left input with stored right input, and also stores all the left input
     //
-    //  - `rightJoiner.joinWith(leftJoiner)` generates all rows from matching new right input with
-    //    stored left input, and also stores all the right input. It also generates all rows from
-    //    matching new left input with new right input, since the new left input has become stored
-    //    by that point. This tiny asymmetry is necessary to avoid duplication.
+    //  - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)` generates all rows from
+    //    matching new right input with stored left input, and also stores all the right input.
+    //    It also generates all rows from matching new left input with new right input, since
+    //    the new left input has become stored by that point. This tiny asymmetry is necessary
+    //    to avoid duplication.
     val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) {
       (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched)
     }
 
@@ -459,8 +460,9 @@ case class StreamingSymmetricHashJoinExec(
      */
     def storeAndJoinWithOtherSide(
         otherSideJoiner: OneSideHashJoiner)(
-        generateJoinedRow: (InternalRow, InternalRow) => JoinedRow):
-      Iterator[InternalRow] = {
+        generateJoinedRow: (InternalRow, InternalRow) => JoinedRow)
+      : Iterator[InternalRow] = {
+
       val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey))
       val nonLateRows =
         WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match {
@@ -471,6 +473,14 @@ case class StreamingSymmetricHashJoinExec(
             inputIter
         }
 
+      val generateFilteredJoinedRow: InternalRow => Iterator[InternalRow] = joinSide match {
+        case LeftSide if joinType == LeftOuter =>
+          (row: InternalRow) => Iterator(generateJoinedRow(row, nullRight))
+        case RightSide if joinType == RightOuter =>
+          (row: InternalRow) => Iterator(generateJoinedRow(row, nullLeft))
+        case _ => (_: InternalRow) => Iterator.empty
+      }
+
       nonLateRows.flatMap { row =>
         val thisRow = row.asInstanceOf[UnsafeRow]
         // If this row fails the pre join filter, that means it can never satisfy the full join
@@ -483,13 +493,7 @@ case class StreamingSymmetricHashJoinExec(
             .getJoinedRows(key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter)
           new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter)
         } else {
-          joinSide match {
-            case LeftSide if joinType == LeftOuter =>
-              Iterator(generateJoinedRow(thisRow, nullRight))
-            case RightSide if joinType == RightOuter =>
-              Iterator(generateJoinedRow(thisRow, nullLeft))
-            case _ => Iterator()
-          }
+          generateFilteredJoinedRow(thisRow)
         }
       }
     }
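
For reviewers, a minimal sketch (spark-shell style) of the kind of query that is planned as StreamingSymmetricHashJoinExec and exercises the left-outer path touched above. This is illustrative only and not part of the patch; the rate sources, the column names (impressionAdId, impressionTime, clickAdId, clickTime), and the watermark/range bounds are assumptions chosen to keep the example self-contained.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

val spark = SparkSession.builder().appName("stream-stream-left-outer-join").getOrCreate()

// Two rate sources standing in for real event streams; each side declares an event-time watermark.
val impressions = spark.readStream.format("rate").load()
  .selectExpr("value AS impressionAdId", "timestamp AS impressionTime")
  .withWatermark("impressionTime", "2 hours")

val clicks = spark.readStream.format("rate").load()
  .selectExpr("value AS clickAdId", "timestamp AS clickTime")
  .withWatermark("clickTime", "3 hours")

// Left outer equi-join with an event-time range condition; unmatched impressions are emitted
// with null click-side columns once the watermark guarantees no matching click can arrive.
val joined = impressions.join(
  clicks,
  expr("""
    clickAdId = impressionAdId AND
    clickTime >= impressionTime AND
    clickTime <= impressionTime + interval 1 hour
  """),
  joinType = "leftOuter")

joined.writeStream.format("console").start().awaitTermination()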