
Commit 4efa58b

[SPARK-3660][STREAMING] Initial RDD for updateStateByKey transformation
1 parent 8f40ca0 commit 4efa58b

File tree

4 files changed: +94 additions, -82 deletions
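
For orientation: the overload of updateStateByKey exercised by this commit takes an iterator-based update function, a partitioner, a flag for remembering the partitioner, and the new initial-state RDD. The signature below is a sketch inferred from the call sites in this diff (for a DStream of (K, V) pairs), not a line taken from the changed files:

  def updateStateByKey[S: ClassTag](
      updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
      partitioner: Partitioner,
      rememberPartitioner: Boolean,
      initialRDD: RDD[(K, S)]
    ): DStream[(K, S)]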

examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala

Lines changed: 11 additions & 2 deletions

@@ -18,12 +18,13 @@
 package org.apache.spark.examples.streaming

 import org.apache.spark.SparkConf
+import org.apache.spark.HashPartitioner
 import org.apache.spark.streaming._
 import org.apache.spark.streaming.StreamingContext._

 /**
  * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
- * second.
+ * second, starting with an initial value for each word count.
  * Usage: StatefulNetworkWordCount <hostname> <port>
  *   <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
  *   data.
@@ -51,11 +52,18 @@ object StatefulNetworkWordCount {
       Some(currentCount + previousCount)
     }

+    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
+      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+    }
+
     val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
     // Create the context with a 1 second batch size
     val ssc = new StreamingContext(sparkConf, Seconds(1))
     ssc.checkpoint(".")

+    // Initial RDD input to updateStateByKey
+    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
+
     // Create a NetworkInputDStream on target ip:port and count the
     // words in the input stream of \n delimited text (e.g. generated by 'nc')
     val lines = ssc.socketTextStream(args(0), args(1).toInt)
@@ -64,7 +72,8 @@ object StatefulNetworkWordCount {

     // Update the cumulative count using updateStateByKey
     // This will give a DStream made of state (which is the cumulative count of the words)
-    val stateDstream = wordDstream.updateStateByKey[Int](updateFunc)
+    val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
+      new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)
     stateDstream.print()
     ssc.start()
     ssc.awaitTermination()
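
To see what the seeded state does on the first batch, here is a minimal, Spark-free sketch of the iterator-based update function above, run on plain Scala collections. The batch contents and the UpdateFuncSketch name are illustrative assumptions, not part of the commit:

  object UpdateFuncSketch {
    def main(args: Array[String]): Unit = {
      val updateFunc = (values: Seq[Int], state: Option[Int]) =>
        Some(values.sum + state.getOrElse(0))

      val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) =>
        iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))

      // "hello" arrives once and carries the seeded initial count of 1, so its
      // new state is 2; "foo" has no seed, so its state is just its batch count.
      val firstBatch = Iterator(
        ("hello", Seq(1), Some(1)),  // seeded by initialRDD
        ("foo",   Seq(1), None)      // unseen key, no initial state
      )
      newUpdateFunc(firstBatch).foreach(println)  // prints (hello,2) then (foo,1)
    }
  }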

examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCountWithInitial.scala

Lines changed: 0 additions & 80 deletions

This file was deleted; its initial-state behavior is now folded into StatefulNetworkWordCount above.

streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala

Lines changed: 38 additions & 0 deletions

@@ -21,6 +21,7 @@ import org.apache.spark.streaming.StreamingContext._

 import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.SparkContext._
+import org.apache.spark.HashPartitioner

 import util.ManualClock
 import org.apache.spark.{SparkException, SparkConf}
@@ -349,6 +350,43 @@ class BasicOperationsSuite extends TestSuiteBase {
     testOperation(inputData, updateStateOperation, outputData, true)
   }

+  test("updateStateByKey - with initial value RDD") {
+    val initial = Seq(("a", 1), ("c", 2))
+
+    val inputData =
+      Seq(
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(("a", 2), ("c", 2)),
+        Seq(("a", 3), ("b", 1), ("c", 2)),
+        Seq(("a", 4), ("b", 2), ("c", 3)),
+        Seq(("a", 5), ("b", 3), ("c", 3)),
+        Seq(("a", 6), ("b", 3), ("c", 3)),
+        Seq(("a", 6), ("b", 3), ("c", 3))
+      )
+
+    val updateStateOperation = (s: DStream[String], initialRDD: RDD[(String, Int)]) => {
+      val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+        Some(values.sum + state.getOrElse(0))
+      }
+      val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
+        iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+      }
+      s.map(x => (x, 1)).updateStateByKey[Int](newUpdateFunc,
+        new HashPartitioner(numInputPartitions), true, initialRDD)
+    }
+
+    testOperationWithInitial(initial, inputData, updateStateOperation, outputData, true)
+  }
+
   test("updateStateByKey - object lifecycle") {
     val inputData =
       Seq(
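
The expectedOutput sequence in the new test can be checked by hand: the initial RDD seeds ("a", 1) and ("c", 2), and each batch then adds one to the count of every word it contains. A small, Spark-free replay of the batches reproduces the table (ReplaySketch is an illustrative name, not part of the commit):

  object ReplaySketch {
    def main(args: Array[String]): Unit = {
      val initial = Map("a" -> 1, "c" -> 2)
      val batches = Seq(Seq("a"), Seq("a", "b"), Seq("a", "b", "c"),
                        Seq("a", "b"), Seq("a"), Seq())
      // scanLeft threads the running state through the batches; drop(1)
      // discards the pre-batch seed state.
      batches.scanLeft(initial) { (state, batch) =>
        batch.foldLeft(state)((s, word) => s.updated(word, s.getOrElse(word, 0) + 1))
      }.drop(1).foreach(println)
      // First line: Map(a -> 2, c -> 2), matching the first expected batch above.
    }
  }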

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 45 additions & 0 deletions

@@ -212,6 +212,34 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
     ssc
   }

+  /**
+   * Set up required DStreams to test the DStream operation using the sequence
+   * of input collections and an initial sequence.
+   */
+  def setupStreamsWithInitial[U: ClassTag, V: ClassTag](
+      initial: Seq[V],
+      input: Seq[Seq[U]],
+      operation: (DStream[U], RDD[V]) => DStream[V],
+      numPartitions: Int = numInputPartitions
+    ): StreamingContext = {
+    // Create StreamingContext
+    val ssc = new StreamingContext(conf, batchDuration)
+    if (checkpointDir != null) {
+      ssc.checkpoint(checkpointDir)
+    }
+
+    // Create initial value RDD
+    val initialRDD = ssc.sc.makeRDD(initial, numInputPartitions)
+
+    // Setup the stream computation
+    val inputStream = new TestInputStream(ssc, input, numPartitions)
+    val operatedStream = operation(inputStream, initialRDD)
+    val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+      new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
+    outputStream.register()
+    ssc
+  }
+
   /**
    * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
    * returns the collected output. It will wait until `numExpectedOutput` number of
@@ -321,6 +349,23 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
     logInfo("Output verified successfully")
   }

+  /**
+   * Test a unary DStream operation with a list of inputs and initial values, running
+   * as many batches as there are expected output values.
+   */
+  def testOperationWithInitial[U: ClassTag, V: ClassTag](
+      initial: Seq[V],
+      input: Seq[Seq[U]],
+      operation: (DStream[U], RDD[V]) => DStream[V],
+      expectedOutput: Seq[Seq[V]],
+      useSet: Boolean = false
+    ) {
+    val numBatches_ = expectedOutput.size
+    val ssc = setupStreamsWithInitial[U, V](initial, input, operation)
+    val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
+    verifyOutput[V](output, expectedOutput, useSet)
+  }
+
   /**
    * Test unary DStream operation with a list of inputs, with number of
    * batches to run same as the number of expected output values
