@@ -20,18 +20,17 @@ package org.apache.spark.sql.execution.streaming
 import java.util.UUID
 import java.util.concurrent.TimeUnit._
 
-import scala.collection.JavaConverters._
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner, Predicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
 import org.apache.spark.sql.types._
@@ -204,35 +203,18 @@ case class StateStoreRestoreExec(
     child: SparkPlan)
   extends UnaryExecNode with StateStoreReader {
 
-  val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
-  if (removeRedundant) {
-    log.info("Advanced option removeRedundantInStatefulAggregation activated!")
-  }
-
-  val valueExpressions: Seq[Attribute] = if (removeRedundant) {
-    child.output.diff(keyExpressions)
-  } else {
-    child.output
-  }
-  val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
-  val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output
-
   override protected def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
+      sqlContext.conf)
 
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
-      valueExpressions.toStructType,
+      stateManager.getValueExpressions.toStructType,
       indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
-        val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
-        val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
-          StructType.fromAttributes(valueExpressions))
-        val restoreValueProject = GenerateUnsafeProjection.generate(
-          keyValueJoinedExpressions, child.output)
-
         val hasInput = iter.hasNext
         if (!hasInput && keyExpressions.isEmpty) {
           // If our `keyExpressions` are empty, we're getting a global aggregation. In that case
@@ -243,23 +225,8 @@ case class StateStoreRestoreExec(
           store.iterator().map(_.value)
         } else {
           iter.flatMap { row =>
-            val key = getKey(row)
-            val savedState = store.get(key)
-            val restoredRow = if (removeRedundant) {
-              if (savedState == null) {
-                savedState
-              } else {
-                val joinedRow = joiner.join(key, savedState)
-                if (needToProjectToRestoreValue) {
-                  restoreValueProject(joinedRow)
-                } else {
-                  joinedRow
-                }
-              }
-            } else {
-              savedState
-            }
-
+            val key = stateManager.extractKey(row)
+            val restoredRow = stateManager.get(store, key)
             numOutputRows += 1
             Option(restoredRow).toSeq :+ row
           }
@@ -291,38 +258,21 @@ case class StateStoreSaveExec(
     child: SparkPlan)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
 
-  val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
-  if (removeRedundant) {
-    log.info("Advanced option removeRedundantInStatefulAggregation activated!")
-  }
-
-  val valueExpressions: Seq[Attribute] = if (removeRedundant) {
-    child.output.diff(keyExpressions)
-  } else {
-    child.output
-  }
-  val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
-  val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output
-
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
     assert(outputMode.nonEmpty,
       "Incorrect planning in IncrementalExecution, outputMode has not been set")
 
+    val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
+      sqlContext.conf)
+
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
-      valueExpressions.toStructType,
+      stateManager.getValueExpressions.toStructType,
       indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
-        val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
-        val getValue = GenerateUnsafeProjection.generate(valueExpressions, child.output)
-        val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
-          StructType.fromAttributes(valueExpressions))
-        val restoreValueProject = GenerateUnsafeProjection.generate(
-          keyValueJoinedExpressions, child.output)
-
         val numOutputRows = longMetric("numOutputRows")
         val numUpdatedStateRows = longMetric("numUpdatedStateRows")
         val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
@@ -335,13 +285,7 @@ case class StateStoreSaveExec(
             allUpdatesTimeMs += timeTakenMs {
               while (iter.hasNext) {
                 val row = iter.next().asInstanceOf[UnsafeRow]
-                val key = getKey(row)
-                val value = if (removeRedundant) {
-                  getValue(row)
-                } else {
-                  row
-                }
-                store.put(key, value)
+                stateManager.put(store, row)
                 numUpdatedStateRows += 1
               }
             }
@@ -352,18 +296,7 @@ case class StateStoreSaveExec(
             setStoreMetrics(store)
             store.iterator().map { rowPair =>
               numOutputRows += 1
-
-              if (removeRedundant) {
-                val joinedRow = joiner.join(rowPair.key, rowPair.value)
-                if (needToProjectToRestoreValue) {
-                  restoreValueProject(joinedRow)
-                } else {
-                  joinedRow
-                }
-              } else {
-                rowPair.value
-              }
-
+              stateManager.restoreOriginRow(rowPair)
             }
 
           // Update and output only rows being evicted from the StateStore
@@ -373,13 +306,7 @@ case class StateStoreSaveExec(
               val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row))
               while (filteredIter.hasNext) {
                 val row = filteredIter.next().asInstanceOf[UnsafeRow]
-                val key = getKey(row)
-                val value = if (removeRedundant) {
-                  getValue(row)
-                } else {
-                  row
-                }
-                store.put(key, value)
+                stateManager.put(store, row)
                 numUpdatedStateRows += 1
               }
             }
@@ -394,17 +321,7 @@ case class StateStoreSaveExec(
                   val rowPair = rangeIter.next()
                   if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
                     store.remove(rowPair.key)
-
-                    if (removeRedundant) {
-                      val joinedRow = joiner.join(rowPair.key, rowPair.value)
-                      removedValueRow = if (needToProjectToRestoreValue) {
-                        restoreValueProject(joinedRow)
-                      } else {
-                        joinedRow
-                      }
-                    } else {
-                      removedValueRow = rowPair.value
-                    }
+                    removedValueRow = stateManager.restoreOriginRow(rowPair)
                   }
                 }
                 if (removedValueRow == null) {
@@ -436,13 +353,7 @@ case class StateStoreSaveExec(
               override protected def getNext(): InternalRow = {
                 if (baseIterator.hasNext) {
                   val row = baseIterator.next().asInstanceOf[UnsafeRow]
-                  val key = getKey(row)
-                  val value = if (removeRedundant) {
-                    getValue(row)
-                  } else {
-                    row
-                  }
-                  store.put(key, value)
+                  stateManager.put(store, row)
                   numOutputRows += 1
                   numUpdatedStateRows += 1
                   row
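
The hunks above only show the call sites of the new helper; its definition lives in StatefulOperatorsHelper, which is not part of this diff. As a rough orientation, here is a minimal sketch of what an implementation behind that interface could look like, assuming the "store only the non-key columns" strategy that the removed removeRedundant branches implemented inline. Every name and signature below is inferred from the call sites (newImpl, getValueExpressions, extractKey, get, put, restoreOriginRow) and is illustrative, not the patch's actual code.

// Illustrative sketch only: not the StatefulOperatorsHelper added by this patch.
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

object StatefulOperatorsHelper {

  trait StreamingAggregationStateManager extends Serializable {
    def getValueExpressions: Seq[Attribute]                  // schema of the rows kept in the state store
    def extractKey(row: InternalRow): UnsafeRow              // project an input row down to its grouping key
    def get(store: StateStore, key: UnsafeRow): UnsafeRow    // restored full row for a key, or null if absent
    def put(store: StateStore, row: UnsafeRow): Unit         // persist an input row under its key
    def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow  // rebuild the original row from a (key, value) pair
  }

  object StreamingAggregationStateManager {
    // Hypothetical factory: the conf flag referenced by the removed code
    // (advancedRemoveRedundantInStatefulAggregation) would choose between the legacy layout
    // (whole row as the value) and the redundant-column-removing layout; the sketch always
    // returns the latter.
    def newImpl(
        keyExpressions: Seq[Attribute],
        inputRowAttributes: Seq[Attribute],
        conf: SQLConf): StreamingAggregationStateManager =
      new RemoveRedundantColumnsStateManager(keyExpressions, inputRowAttributes)
  }

  // Stores only the non-key columns as the value and joins key and value back together on read.
  class RemoveRedundantColumnsStateManager(
      keyExpressions: Seq[Attribute],
      inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager {

    private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions)
    private val keyValueJoined: Seq[Attribute] = keyExpressions ++ valueExpressions
    // Only reorder columns when the key-first layout differs from the original row layout.
    private val needToProject: Boolean = keyValueJoined != inputRowAttributes

    // Codegen'ed projections are rebuilt lazily on each executor rather than serialized.
    @transient private lazy val keyProjector =
      GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)
    @transient private lazy val valueProjector =
      GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes)
    @transient private lazy val joiner = GenerateUnsafeRowJoiner.create(
      StructType.fromAttributes(keyExpressions), StructType.fromAttributes(valueExpressions))
    @transient private lazy val restoreProjector =
      GenerateUnsafeProjection.generate(inputRowAttributes, keyValueJoined)

    override def getValueExpressions: Seq[Attribute] = valueExpressions

    override def extractKey(row: InternalRow): UnsafeRow = keyProjector(row)

    override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
      val savedValue = store.get(key)
      if (savedValue == null) null else restore(key, savedValue)
    }

    override def put(store: StateStore, row: UnsafeRow): Unit =
      store.put(extractKey(row), valueProjector(row))

    override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow =
      restore(rowPair.key, rowPair.value)

    // Join the key back onto the stored value, then restore the original column order if needed.
    private def restore(key: UnsafeRow, value: UnsafeRow): UnsafeRow = {
      val joined = joiner.join(key, value)
      if (needToProject) restoreProjector(joined) else joined
    }
  }
}

Because the operators build the manager once in doExecute and ship it inside the task closure, a real implementation would also need its codegen'ed projections to be recreated on the executors, which is why the sketch marks them @transient lazy.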