Skip to content

Commit 977428c

Browse files
committed
Refine code change: introduce trait and classes to group duplicate methods
1 parent abec57f commit 977428c

File tree

2 files changed

+152
-105
lines changed

2 files changed

+152
-105
lines changed
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.streaming
19+
20+
import org.apache.spark.internal.Logging
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
23+
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
24+
import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair}
25+
import org.apache.spark.sql.internal.SQLConf
26+
import org.apache.spark.sql.types.StructType
27+
28+
object StatefulOperatorsHelper {

  /**
   * Abstracts how streaming-aggregation operators read rows from and write rows to a
   * [[StateStore]], so that `StateStoreRestoreExec` / `StateStoreSaveExec` do not need to
   * branch on the storage format themselves.
   *
   * Implementations must be serializable because they are captured in closures that are
   * shipped to executors.
   */
  sealed trait StreamingAggregationStateManager extends Serializable {

    /** Extracts the grouping-key columns from an input row. */
    def extractKey(row: InternalRow): UnsafeRow

    /** Attributes describing the schema of the value rows kept in the state store. */
    def getValueExpressions: Seq[Attribute]

    /** Reconstructs a row with the original child-output schema from a stored key/value pair. */
    def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow

    /**
     * Returns the row stored for `key`, restored to the original child-output schema,
     * or null when no state exists for the key.
     */
    def get(store: StateStore, key: UnsafeRow): UnsafeRow

    /** Persists `row` into the state store, keyed by its extracted grouping key. */
    def put(store: StateStore, row: UnsafeRow): Unit
  }

  object StreamingAggregationStateManager extends Logging {

    /**
     * Creates the state-manager implementation selected by configuration.
     *
     * @param keyExpressions attributes forming the grouping key
     * @param childOutput    full output schema of the child plan
     * @param conf           SQL configuration; when
     *                       `advancedRemoveRedundantInStatefulAggregation` is enabled the
     *                       space-saving V2 format (key columns not duplicated in the value)
     *                       is used, otherwise the legacy V1 format.
     */
    def newImpl(
        keyExpressions: Seq[Attribute],
        childOutput: Seq[Attribute],
        conf: SQLConf): StreamingAggregationStateManager = {
      if (conf.advancedRemoveRedundantInStatefulAggregation) {
        log.info("Advanced option removeRedundantInStatefulAggregation activated!")
        new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput)
      } else {
        new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput)
      }
    }
  }

  /**
   * Shares the key-extraction logic common to all implementations.
   *
   * @param keyExpressions attributes forming the grouping key
   * @param childOutput    full output schema of the child plan
   */
  abstract class StreamingAggregationStateManagerBaseImpl(
      protected val keyExpressions: Seq[Attribute],
      protected val childOutput: Seq[Attribute]) extends StreamingAggregationStateManager {

    // Transient + lazy: code-generated projections are not serializable, so they are
    // re-created on each executor after the manager is deserialized.
    @transient protected lazy val keyProjector =
      GenerateUnsafeProjection.generate(keyExpressions, childOutput)

    override def extractKey(row: InternalRow): UnsafeRow = keyProjector(row)
  }

  /**
   * V1 (legacy) format: the entire input row is stored as the value, so the key columns
   * are duplicated in the value. Restoring a row is therefore a no-op.
   */
  class StreamingAggregationStateManagerImplV1(
      keyExpressions: Seq[Attribute],
      childOutput: Seq[Attribute])
    extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) {

    override def getValueExpressions: Seq[Attribute] = childOutput

    override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = rowPair.value

    override def get(store: StateStore, key: UnsafeRow): UnsafeRow = store.get(key)

    override def put(store: StateStore, row: UnsafeRow): Unit =
      store.put(extractKey(row), row)
  }

  /**
   * V2 (space-saving) format: only the non-key columns are stored as the value.
   * Restoring the original row requires joining the key back onto the value and,
   * when the key columns were not a prefix of the child output, re-projecting the
   * joined row into the original column order.
   */
  class StreamingAggregationStateManagerImplV2(
      keyExpressions: Seq[Attribute],
      childOutput: Seq[Attribute])
    extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) {

    private val valueExpressions: Seq[Attribute] = childOutput.diff(keyExpressions)
    private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions

    // If (key ++ value) already equals the child output order, the joined row can be
    // returned as-is; otherwise it must be re-projected to restore column order.
    private val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != childOutput

    @transient private lazy val valueProjector =
      GenerateUnsafeProjection.generate(valueExpressions, childOutput)

    @transient private lazy val joiner =
      GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
        StructType.fromAttributes(valueExpressions))

    @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate(
      keyValueJoinedExpressions, childOutput)

    override def getValueExpressions: Seq[Attribute] = valueExpressions

    override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow =
      joinAndRestore(rowPair.key, rowPair.value)

    override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
      val savedState = store.get(key)
      // Express the absent-state case with if/else instead of an early `return`,
      // which is non-idiomatic Scala and unsafe inside closures.
      if (savedState == null) {
        savedState
      } else {
        joinAndRestore(key, savedState)
      }
    }

    override def put(store: StateStore, row: UnsafeRow): Unit = {
      val key = keyProjector(row)
      val value = valueProjector(row)
      store.put(key, value)
    }

    /** Joins a key and a stored value back into a row with the original child-output schema. */
    private def joinAndRestore(key: UnsafeRow, value: UnsafeRow): UnsafeRow = {
      val joinedRow = joiner.join(key, value)
      if (needToProjectToRestoreValue) {
        restoreValueProjector(joinedRow)
      } else {
        joinedRow
      }
    }
  }
}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala

Lines changed: 16 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,17 @@ package org.apache.spark.sql.execution.streaming
2020
import java.util.UUID
2121
import java.util.concurrent.TimeUnit._
2222

23-
import scala.collection.JavaConverters._
24-
2523
import org.apache.spark.rdd.RDD
2624
import org.apache.spark.sql.catalyst.InternalRow
2725
import org.apache.spark.sql.catalyst.errors._
2826
import org.apache.spark.sql.catalyst.expressions._
29-
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner, Predicate}
27+
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
3028
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
3129
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
3230
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
3331
import org.apache.spark.sql.execution._
3432
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
33+
import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager
3534
import org.apache.spark.sql.execution.streaming.state._
3635
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
3736
import org.apache.spark.sql.types._
@@ -204,35 +203,18 @@ case class StateStoreRestoreExec(
204203
child: SparkPlan)
205204
extends UnaryExecNode with StateStoreReader {
206205

207-
val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
208-
if (removeRedundant) {
209-
log.info("Advanced option removeRedundantInStatefulAggregation activated!")
210-
}
211-
212-
val valueExpressions: Seq[Attribute] = if (removeRedundant) {
213-
child.output.diff(keyExpressions)
214-
} else {
215-
child.output
216-
}
217-
val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
218-
val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output
219-
220206
override protected def doExecute(): RDD[InternalRow] = {
221207
val numOutputRows = longMetric("numOutputRows")
208+
val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
209+
sqlContext.conf)
222210

223211
child.execute().mapPartitionsWithStateStore(
224212
getStateInfo,
225213
keyExpressions.toStructType,
226-
valueExpressions.toStructType,
214+
stateManager.getValueExpressions.toStructType,
227215
indexOrdinal = None,
228216
sqlContext.sessionState,
229217
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
230-
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
231-
val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
232-
StructType.fromAttributes(valueExpressions))
233-
val restoreValueProject = GenerateUnsafeProjection.generate(
234-
keyValueJoinedExpressions, child.output)
235-
236218
val hasInput = iter.hasNext
237219
if (!hasInput && keyExpressions.isEmpty) {
238220
// If our `keyExpressions` are empty, we're getting a global aggregation. In that case
@@ -243,23 +225,8 @@ case class StateStoreRestoreExec(
243225
store.iterator().map(_.value)
244226
} else {
245227
iter.flatMap { row =>
246-
val key = getKey(row)
247-
val savedState = store.get(key)
248-
val restoredRow = if (removeRedundant) {
249-
if (savedState == null) {
250-
savedState
251-
} else {
252-
val joinedRow = joiner.join(key, savedState)
253-
if (needToProjectToRestoreValue) {
254-
restoreValueProject(joinedRow)
255-
} else {
256-
joinedRow
257-
}
258-
}
259-
} else {
260-
savedState
261-
}
262-
228+
val key = stateManager.extractKey(row)
229+
val restoredRow = stateManager.get(store, key)
263230
numOutputRows += 1
264231
Option(restoredRow).toSeq :+ row
265232
}
@@ -291,38 +258,21 @@ case class StateStoreSaveExec(
291258
child: SparkPlan)
292259
extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
293260

294-
val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
295-
if (removeRedundant) {
296-
log.info("Advanced option removeRedundantInStatefulAggregation activated!")
297-
}
298-
299-
val valueExpressions: Seq[Attribute] = if (removeRedundant) {
300-
child.output.diff(keyExpressions)
301-
} else {
302-
child.output
303-
}
304-
val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
305-
val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output
306-
307261
override protected def doExecute(): RDD[InternalRow] = {
308262
metrics // force lazy init at driver
309263
assert(outputMode.nonEmpty,
310264
"Incorrect planning in IncrementalExecution, outputMode has not been set")
311265

266+
val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
267+
sqlContext.conf)
268+
312269
child.execute().mapPartitionsWithStateStore(
313270
getStateInfo,
314271
keyExpressions.toStructType,
315-
valueExpressions.toStructType,
272+
stateManager.getValueExpressions.toStructType,
316273
indexOrdinal = None,
317274
sqlContext.sessionState,
318275
Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
319-
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
320-
val getValue = GenerateUnsafeProjection.generate(valueExpressions, child.output)
321-
val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
322-
StructType.fromAttributes(valueExpressions))
323-
val restoreValueProject = GenerateUnsafeProjection.generate(
324-
keyValueJoinedExpressions, child.output)
325-
326276
val numOutputRows = longMetric("numOutputRows")
327277
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
328278
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
@@ -335,13 +285,7 @@ case class StateStoreSaveExec(
335285
allUpdatesTimeMs += timeTakenMs {
336286
while (iter.hasNext) {
337287
val row = iter.next().asInstanceOf[UnsafeRow]
338-
val key = getKey(row)
339-
val value = if (removeRedundant) {
340-
getValue(row)
341-
} else {
342-
row
343-
}
344-
store.put(key, value)
288+
stateManager.put(store, row)
345289
numUpdatedStateRows += 1
346290
}
347291
}
@@ -352,18 +296,7 @@ case class StateStoreSaveExec(
352296
setStoreMetrics(store)
353297
store.iterator().map { rowPair =>
354298
numOutputRows += 1
355-
356-
if (removeRedundant) {
357-
val joinedRow = joiner.join(rowPair.key, rowPair.value)
358-
if (needToProjectToRestoreValue) {
359-
restoreValueProject(joinedRow)
360-
} else {
361-
joinedRow
362-
}
363-
} else {
364-
rowPair.value
365-
}
366-
299+
stateManager.restoreOriginRow(rowPair)
367300
}
368301

369302
// Update and output only rows being evicted from the StateStore
@@ -373,13 +306,7 @@ case class StateStoreSaveExec(
373306
val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row))
374307
while (filteredIter.hasNext) {
375308
val row = filteredIter.next().asInstanceOf[UnsafeRow]
376-
val key = getKey(row)
377-
val value = if (removeRedundant) {
378-
getValue(row)
379-
} else {
380-
row
381-
}
382-
store.put(key, value)
309+
stateManager.put(store, row)
383310
numUpdatedStateRows += 1
384311
}
385312
}
@@ -394,17 +321,7 @@ case class StateStoreSaveExec(
394321
val rowPair = rangeIter.next()
395322
if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
396323
store.remove(rowPair.key)
397-
398-
if (removeRedundant) {
399-
val joinedRow = joiner.join(rowPair.key, rowPair.value)
400-
removedValueRow = if (needToProjectToRestoreValue) {
401-
restoreValueProject(joinedRow)
402-
} else {
403-
joinedRow
404-
}
405-
} else {
406-
removedValueRow = rowPair.value
407-
}
324+
removedValueRow = stateManager.restoreOriginRow(rowPair)
408325
}
409326
}
410327
if (removedValueRow == null) {
@@ -436,13 +353,7 @@ case class StateStoreSaveExec(
436353
override protected def getNext(): InternalRow = {
437354
if (baseIterator.hasNext) {
438355
val row = baseIterator.next().asInstanceOf[UnsafeRow]
439-
val key = getKey(row)
440-
val value = if (removeRedundant) {
441-
getValue(row)
442-
} else {
443-
row
444-
}
445-
store.put(key, value)
356+
stateManager.put(store, row)
446357
numOutputRows += 1
447358
numUpdatedStateRows += 1
448359
row

0 commit comments

Comments
 (0)