Skip to content

Commit fb19879

Browse files
committed
WIP: apply session merging in each partition before shuffling
1 parent bbbb844 commit fb19879

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,11 @@ object AggUtils {
348348
/**
349349
* Plans a streaming session aggregation using the following progression:
350350
*
351-
* - Partial Merge (group: all keys)
351+
* - Partial Aggregation
352352
* - all tuples will have aggregated columns with initial value
353+
* - Sort within partition (sort: all keys)
354+
* - MergingSessionsExec (group: all keys)
355+
* - This acts as a "Partial Merge" within each partition
353356
* - Shuffle & Sort (distribution: keys "without" session, sort: all keys)
354357
* - SessionWindowStateStoreRestore (group: keys "without" session)
355358
* - merge input tuples with stored tuples (sessions) respecting sort order
@@ -391,9 +394,28 @@ object AggUtils {
391394
child = child)
392395
}
393396

397+
// sort happens here to merge sessions on each partition
398+
// this is to reduce the number of rows to shuffle
399+
val partialMerged1: SparkPlan = {
400+
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
401+
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
402+
MergingSessionsExec(
403+
requiredChildDistributionExpressions = None,
404+
requiredChildDistributionOption = None,
405+
groupingExpressions = groupingAttributes,
406+
sessionExpression = sessionExpression,
407+
aggregateExpressions = aggregateExpressions,
408+
aggregateAttributes = aggregateAttributes,
409+
initialInputBufferOffset = groupingAttributes.length,
410+
resultExpressions = groupingAttributes ++
411+
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
412+
child = partialAggregate
413+
)
414+
}
415+
394416
// shuffle & sort happens here: most of details are also handled in this physical plan
395417
val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes,
396-
sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, partialAggregate)
418+
sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, partialMerged1)
397419

398420
val mergedSessions = {
399421
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))

sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
638638
testStream(sessionUpdates, OutputMode.Update())(
639639
AddData(inputData, ("hello world spark", 10L), ("world hello structured streaming", 11L)),
640640
// Advance watermark to 1 second
641+
// current sessions after batch:
641642
// ("hello", 10, 21, 11, 2)
642643
// ("world", 10, 21, 11, 2)
643644
// ("spark", 10, 20, 10, 1)

0 commit comments

Comments
 (0)