@@ -348,8 +348,11 @@ object AggUtils {
348348 /**
349349 * Plans a streaming session aggregation using the following progression:
350350 *
351- * - Partial Merge (group: all keys)
351+ * - Partial Aggregation
352352 * - all tuples will have aggregated columns with initial value
353+ * - Sort within partition (sort: all keys)
354+ * - MergingSessions (group: all keys)
355+ * - merges sessions within each partition, acting as a "Partial Merge" before the shuffle
353356 * - Shuffle & Sort (distribution: keys "without" session, sort: all keys)
354357 * - SessionWindowStateStoreRestore (group: keys "without" session)
355358 * - merge input tuples with stored tuples (sessions) respecting sort order
@@ -391,9 +394,28 @@ object AggUtils {
391394 child = child)
392395 }
393396
397+ // Merge sessions within each partition before shuffling (the sort needed for merging happens
398+ // here); this reduces the number of rows that have to be shuffled.
399+ val partialMerged1 : SparkPlan = {
400+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge ))
401+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
402+ MergingSessionsExec (
403+ requiredChildDistributionExpressions = None ,
404+ requiredChildDistributionOption = None ,
405+ groupingExpressions = groupingAttributes,
406+ sessionExpression = sessionExpression,
407+ aggregateExpressions = aggregateExpressions,
408+ aggregateAttributes = aggregateAttributes,
409+ initialInputBufferOffset = groupingAttributes.length,
410+ resultExpressions = groupingAttributes ++
411+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
412+ child = partialAggregate
413+ )
414+ }
415+
394416 // shuffle & sort happen here: most of the details are also handled in this physical plan
395417 val restored = SessionWindowStateStoreRestoreExec (groupingWithoutSessionAttributes,
396- sessionExpression.toAttribute, stateInfo = None , eventTimeWatermark = None , partialAggregate )
418+ sessionExpression.toAttribute, stateInfo = None , eventTimeWatermark = None , partialMerged1 )
397419
398420 val mergedSessions = {
399421 val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge ))
0 commit comments