@@ -56,8 +56,8 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
  * - Apply the optional condition to filter the joined rows as the final output.
  *
  * If a timestamp column with event time watermark is present in the join keys or in the input
- * data, then the it uses the watermark figure out which rows in the buffer will not join with
- * and the new data, and therefore can be discarded. Depending on the provided query conditions, we
+ * data, then it uses the watermark to figure out which rows in the buffer will not join with
+ * the new data, and therefore can be discarded. Depending on the provided query conditions, we
  * can define thresholds on both state key (i.e. joining keys) and state value (i.e. input rows).
  * There are three kinds of queries possible regarding this as explained below.
  * Assume that watermark has been defined on both `leftTime` and `rightTime` columns used below.
@@ -134,7 +134,7 @@ case class StreamingSymmetricHashJoinExec(
     stateWatermarkPredicates: JoinStateWatermarkPredicates,
     stateFormatVersion: Int,
     left: SparkPlan,
-    right: SparkPlan) extends SparkPlan with BinaryExecNode with StateStoreWriter {
+    right: SparkPlan) extends BinaryExecNode with StateStoreWriter {

   def this(
       leftKeys: Seq[Expression],
@@ -157,14 +157,16 @@ case class StreamingSymmetricHashJoinExec(
         " the checkpoint and rerun the query. See SPARK-26154 for more details.")
   }

+  private lazy val errorMessageForJoinType =
+    s"${getClass.getSimpleName} should not take $joinType as the JoinType"
+
   private def throwBadJoinTypeException(): Nothing = {
-    throw new IllegalArgumentException(
-      s"${getClass.getSimpleName} should not take $joinType as the JoinType")
+    throw new IllegalArgumentException(errorMessageForJoinType)
   }

   require(
     joinType == Inner || joinType == LeftOuter || joinType == RightOuter,
-    s"${getClass.getSimpleName} should not take $joinType as the JoinType")
+    errorMessageForJoinType)
   require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType))

   private val storeConf = new StateStoreConf(sqlContext.conf)
@@ -189,11 +191,9 @@ case class StreamingSymmetricHashJoinExec(
   override def outputPartitioning: Partitioning = joinType match {
     case _: InnerLike =>
       PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
-    case LeftOuter => PartitioningCollection(Seq(left.outputPartitioning))
-    case RightOuter => PartitioningCollection(Seq(right.outputPartitioning))
-    case x =>
-      throw new IllegalArgumentException(
-        s"${getClass.getSimpleName} should not take $x as the JoinType")
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case _ => throwBadJoinTypeException()
Contributor Author

@xuanyuanking - sorry, I don't follow how to change this. Are you suggesting a string val for the error message, used in both throwBadJoinTypeException and require(...) (val errorMessageForJoinType = s"${getClass.getSimpleName} should not take $joinType as the JoinType"), or something else?

Member

Yes, have a string val for the same error message.

Contributor Author

@xuanyuanking - sure, updated.

   }

   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
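
The shared error message settled on in the review thread above can be illustrated outside Spark. What follows is a minimal, self-contained sketch and not the Spark code: `JoinNode`, `describe`, and the locally defined `JoinType` objects are hypothetical stand-ins. Only the pattern mirrors the diff: one lazily built message reused by `require` and by a helper that returns `Nothing`.

```scala
// Hypothetical stand-ins for Spark's join types, defined locally so the sketch runs alone.
sealed trait JoinType
case object Inner extends JoinType
case object LeftOuter extends JoinType
case object RightOuter extends JoinType
case object FullOuter extends JoinType

// Hypothetical operator; only the error-message pattern mirrors the diff.
final class JoinNode(val joinType: JoinType) {
  // Built once on first use and shared by both checks below.
  private lazy val errorMessageForJoinType =
    s"${getClass.getSimpleName} should not take $joinType as the JoinType"

  // Returning Nothing lets the helper sit in any branch of a match.
  private def throwBadJoinTypeException(): Nothing =
    throw new IllegalArgumentException(errorMessageForJoinType)

  // Constructor-time validation reuses the same message.
  require(
    joinType == Inner || joinType == LeftOuter || joinType == RightOuter,
    errorMessageForJoinType)

  // Defensive default branch; unreachable while the require above holds.
  def describe: String = joinType match {
    case Inner => "inner"
    case LeftOuter => "left outer"
    case RightOuter => "right outer"
    case _ => throwBadJoinTypeException()
  }
}
```

Note that `require` prefixes its message with `requirement failed: `, so the two code paths produce the same text apart from that prefix.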
@@ -246,13 +246,14 @@ case class StreamingSymmetricHashJoinExec(

     // Join one side input using the other side's buffered/state rows. Here is how it is done.
     //
-    // - `leftJoiner.joinWith(rightJoiner)` generates all rows from matching new left input with
Contributor

I'm not familiar with this part, cc @zsxwing @HeartSaVioR @xuanyuanking

Contributor

The comment seems to be modified only to replace leftJoiner.joinWith(rightJoiner) with leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner), and vice versa for the right side. Other parts aren't modified.

Contributor Author

@cloud-fan, @HeartSaVioR - yes, this just updates the comment, because there's no leftJoiner/rightJoiner/joinWith in the file, and the original author (#19271) presumably meant to refer to leftSideJoiner/rightSideJoiner/storeAndJoinWithOtherSide. I think it makes sense to keep the code and the comment consistent here. This is a minor, comment-only change anyway.

Member

I think the original PR just used pseudocode to explain; either way is OK with me.

-    //   stored right input, and also stores all the left input
+    // - `leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner)` generates all rows from
+    //   matching new left input with stored right input, and also stores all the left input
     //
-    // - `rightJoiner.joinWith(leftJoiner)` generates all rows from matching new right input with
-    //   stored left input, and also stores all the right input. It also generates all rows from
-    //   matching new left input with new right input, since the new left input has become stored
-    //   by that point. This tiny asymmetry is necessary to avoid duplication.
+    // - `rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner)` generates all rows from
+    //   matching new right input with stored left input, and also stores all the right input.
+    //   It also generates all rows from matching new left input with new right input, since
+    //   the new left input has become stored by that point. This tiny asymmetry is necessary
+    //   to avoid duplication.
     val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) {
       (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched)
     }
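
The asymmetry described in the code comment above can be seen in a small in-memory model. This is a rough sketch under stated assumptions, not the Spark implementation: `SideJoiner`, `stored`, and `SymmetricJoinSketch.joinBatch` are hypothetical names, state is a plain mutable map instead of a state store, and watermarks, filters, and outer joins are ignored. The method name `storeAndJoinWithOtherSide` is borrowed from the diff, but its signature here is simplified.

```scala
import scala.collection.mutable

// Hypothetical, simplified stand-in for one side's joiner: state is an in-memory multimap.
final class SideJoiner[K, V] {
  private val state = mutable.Map.empty[K, mutable.ArrayBuffer[V]]

  // Rows stored so far for a key.
  def stored(key: K): Seq[V] = state.get(key).map(_.toSeq).getOrElse(Seq.empty)

  // Join each new row against the other side's stored rows, then store the row.
  def storeAndJoinWithOtherSide[O, R](other: SideJoiner[K, O], input: Seq[(K, V)])(
      joined: (V, O) => R): Seq[R] =
    input.flatMap { case (key, row) =>
      val out = other.stored(key).map(o => joined(row, o))
      state.getOrElseUpdate(key, mutable.ArrayBuffer.empty[V]) += row
      out
    }
}

object SymmetricJoinSketch {
  // Process one batch: the left side goes first, so by the time the right side runs,
  // the new left rows are already stored and each new-left/new-right pair is emitted
  // exactly once (by the right side).
  def joinBatch(
      left: SideJoiner[Int, String],
      right: SideJoiner[Int, String],
      newLeft: Seq[(Int, String)],
      newRight: Seq[(Int, String)]): Seq[(String, String)] = {
    val leftOut = left.storeAndJoinWithOtherSide(right, newLeft)((l, r) => (l, r))
    val rightOut = right.storeAndJoinWithOtherSide(left, newRight)((r, l) => (l, r))
    leftOut ++ rightOut
  }
}
```

In this model, a batch containing a new left row and a matching new right row produces the pair once, from the right side's pass; a later batch with only a new left row joins against the right row stored earlier.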
@@ -459,8 +460,9 @@ case class StreamingSymmetricHashJoinExec(
      */
     def storeAndJoinWithOtherSide(
         otherSideJoiner: OneSideHashJoiner)(
-        generateJoinedRow: (InternalRow, InternalRow) => JoinedRow):
-    Iterator[InternalRow] = {
+        generateJoinedRow: (InternalRow, InternalRow) => JoinedRow)
+      : Iterator[InternalRow] = {
+
       val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey))
       val nonLateRows =
         WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match {
@@ -471,6 +473,14 @@ case class StreamingSymmetricHashJoinExec(
             inputIter
         }

+      val generateFilteredJoinedRow: InternalRow => Iterator[InternalRow] = joinSide match {
+        case LeftSide if joinType == LeftOuter =>
+          (row: InternalRow) => Iterator(generateJoinedRow(row, nullRight))
+        case RightSide if joinType == RightOuter =>
+          (row: InternalRow) => Iterator(generateJoinedRow(row, nullLeft))
+        case _ => (_: InternalRow) => Iterator.empty
+      }
+
       nonLateRows.flatMap { row =>
         val thisRow = row.asInstanceOf[UnsafeRow]
         // If this row fails the pre join filter, that means it can never satisfy the full join
@@ -483,13 +493,7 @@ case class StreamingSymmetricHashJoinExec(
             .getJoinedRows(key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter)
           new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter)
         } else {
-          joinSide match {
-            case LeftSide if joinType == LeftOuter =>
-              Iterator(generateJoinedRow(thisRow, nullRight))
-            case RightSide if joinType == RightOuter =>
-              Iterator(generateJoinedRow(thisRow, nullLeft))
-            case _ => Iterator()
-          }
+          generateFilteredJoinedRow(thisRow)
         }
       }
     }
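
The last two hunks move the `joinSide`/`joinType` match out of the per-row closure, since neither value changes while a joiner processes its input. A minimal sketch of that refactoring pattern follows. It assumes plain strings in place of `InternalRow` and uses `Option` in place of the null-padded rows; `UnmatchedRowSketch`, `unmatchedHandler`, and the locally defined side and join-type objects are hypothetical names for illustration only.

```scala
// Hypothetical stand-ins for Spark's JoinSide and JoinType.
sealed trait JoinSide
case object LeftSide extends JoinSide
case object RightSide extends JoinSide

sealed trait JoinType
case object Inner extends JoinType
case object LeftOuter extends JoinType
case object RightOuter extends JoinType

object UnmatchedRowSketch {
  // (left, right) output pair; None plays the role of the null-padded side.
  type Joined = (Option[String], Option[String])

  // Chosen once per joiner: the side and join type are fixed for the whole input,
  // so the match does not need to be re-evaluated for every row.
  def unmatchedHandler(side: JoinSide, joinType: JoinType): String => Iterator[Joined] =
    side match {
      case LeftSide if joinType == LeftOuter =>
        (row: String) => Iterator((Option(row), Option.empty[String]))
      case RightSide if joinType == RightOuter =>
        (row: String) => Iterator((Option.empty[String], Option(row)))
      case _ => (_: String) => Iterator.empty
    }
}
```

A caller would obtain the handler once, for example `val handler = UnmatchedRowSketch.unmatchedHandler(LeftSide, LeftOuter)`, and then apply `handler(row)` to each row that fails the pre-join filter.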