Commit d357b70

support windowed dstream

1 parent bd13026

3 files changed: 247 additions & 98 deletions

3 files changed

+247
-98
lines changed

python/pyspark/streaming/dstream.py

Lines changed: 58 additions & 31 deletions
@@ -22,8 +22,8 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.util import rddToFileName, RDDFunction, RDDFunction2
 from pyspark.rdd import portable_hash
-from pyspark.streaming.duration import Seconds
-
+from pyspark.streaming.duration import Duration, Seconds
+from pyspark.resultiterable import ResultIterable
 
 __all__ = ["DStream"]
 
@@ -299,13 +299,17 @@ def get_output(rdd, time):
         return result
 
     def transform(self, func):
-        return TransformedRDD(self, lambda a, t: func(a), True)
+        return TransformedDStream(self, lambda a, t: func(a), True)
 
     def transformWithTime(self, func):
-        return TransformedRDD(self, func, False)
+        return TransformedDStream(self, func, False)
 
     def transformWith(self, func, other, keepSerializer=False):
-        return Transformed2RDD(self, lambda a, b, t: func(a, b), other, keepSerializer)
+        jfunc = RDDFunction2(self.ctx, func, self._jrdd_deserializer)
+        dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
+                                                          other._jdstream.dstream(), jfunc)
+        jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer
+        return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer)
 
     def repartitions(self, numPartitions):
         return self.transform(lambda rdd: rdd.repartition(numPartitions))
@@ -336,28 +340,60 @@ def window(self, windowDuration, slideDuration=None):
         s = Seconds(slideDuration)
         return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
 
-    def reduceByWindow(self, reduceFunc, inReduceFunc, windowDuration, slideDuration):
-        pass
-
-    def countByWindow(self, window, slide):
-        pass
-
-    def countByValueAndWindow(self, window, slide, numPartitions=None):
-        pass
-
-    def groupByKeyAndWindow(self, window, slide, numPartitions=None):
-        pass
-
-    def reduceByKeyAndWindow(self, reduceFunc, inReduceFunc, window, slide, numPartitions=None):
-        pass
+    def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration):
+        keyed = self.map(lambda x: (1, x))
+        reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc,
+                                             windowDuration, slideDuration, 1)
+        return reduced.map(lambda (k, v): v)
+
+    def countByWindow(self, windowDuration, slideDuration):
+        return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub,
+                                                    windowDuration, slideDuration)
+
+    def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+        keyed = self.map(lambda x: (x, 1))
+        counted = keyed.reduceByKeyAndWindow(lambda a, b: a + b, lambda a, b: a - b,
+                                             windowDuration, slideDuration, numPartitions)
+        return counted.filter(lambda (k, v): v > 0).count()
+
+    def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+        ls = self.mapValues(lambda x: [x])
+        grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):],
+                                          windowDuration, slideDuration, numPartitions)
+        return grouped.mapValues(ResultIterable)
+
+    def reduceByKeyAndWindow(self, func, invFunc,
+                             windowDuration, slideDuration, numPartitions=None):
+        reduced = self.reduceByKey(func)
+
+        def reduceFunc(a, t):
+            return a.reduceByKey(func, numPartitions)
+
+        def invReduceFunc(a, b, t):
+            b = b.reduceByKey(func, numPartitions)
+            joined = a.leftOuterJoin(b, numPartitions)
+            return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
+
+        if not isinstance(windowDuration, Duration):
+            windowDuration = Seconds(windowDuration)
+        if not isinstance(slideDuration, Duration):
+            slideDuration = Seconds(slideDuration)
+        serializer = reduced._jrdd_deserializer
+        jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
+        jinvReduceFunc = RDDFunction2(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
+        dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
+                                                             jreduceFunc, jinvReduceFunc,
+                                                             windowDuration._jduration,
+                                                             slideDuration._jduration)
+        return DStream(dstream.asJavaDStream(), self._ssc, serializer)
 
     def updateStateByKey(self, updateFunc):
         # FIXME: convert updateFunc to java JFunction2
         jFunc = updateFunc
         return self._jdstream.updateStateByKey(jFunc)
 
 
-class TransformedRDD(DStream):
+class TransformedDStream(DStream):
     def __init__(self, prev, func, reuse=False):
         ssc = prev._ssc
         self._ssc = ssc
@@ -366,7 +402,8 @@ def __init__(self, prev, func, reuse=False):
         self.is_cached = False
         self.is_checkpointed = False
 
-        if isinstance(prev, TransformedRDD) and not prev.is_cached and not prev.is_checkpointed:
+        if (isinstance(prev, TransformedDStream) and
+                not prev.is_cached and not prev.is_checkpointed):
             prev_func = prev.func
             old_func = func
             func = lambda rdd, t: old_func(prev_func(rdd, t), t)
@@ -388,13 +425,3 @@ def _jdstream(self):
                                                            jfunc, self.reuse).asJavaDStream()
         self._jdstream_val = jdstream
         return jdstream
-
-
-class Transformed2RDD(DStream):
-    def __init__(self, prev, func, other, keepSerializer=False):
-        ssc = prev._ssc
-        jfunc = RDDFunction2(ssc._sc, func, prev._jrdd_deserializer)
-        jdstream = ssc._jvm.PythonTransformed2DStream(prev._jdstream.dstream(),
-                                                      other._jdstream.dstream(), jfunc)
-        jrdd_serializer = prev._jrdd_deserializer if keepSerializer else ssc._sc.serializer
-        DStream.__init__(self, jdstream.asJavaDStream(), ssc, jrdd_serializer)
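
The interesting piece above is reduceByKeyAndWindow: rather than re-reducing every RDD in the window on each slide, the previous window state is combined with the newly arrived batch through func, and the batch that just fell out of the window is backed out through invFunc (the leftOuterJoin keeps keys that have nothing to subtract). The stand-alone Python sketch below mimics that update rule on plain dicts; the helper names and sample batches are invented for illustration and it does not touch Spark at all.

# Stand-alone illustration of the invertible windowed reduce:
#   new window state = old state (+) arriving batch (-) departing batch
# Helper names and sample batches below are invented for this sketch.
import operator


def reduce_by_key(pairs, func):
    """Dict-based equivalent of reduceByKey for one batch."""
    out = {}
    for k, v in pairs:
        out[k] = func(out[k], v) if k in out else v
    return out


def slide(window, arriving, departing, func, inv_func):
    """One slide of the window: fold the new batch in, back the old batch out."""
    window = dict(window)
    for k, v in reduce_by_key(arriving, func).items():
        window[k] = func(window[k], v) if k in window else v
    for k, v in reduce_by_key(departing, func).items():
        if k in window:                      # mirrors the leftOuterJoin None-check
            window[k] = inv_func(window[k], v)
    return window


batches = [[("a", 1)], [("a", 2), ("b", 1)], [("a", 3)], [("b", 4)]]
window_len = 2                               # window covers the last 2 batches
state = {}
for t, batch in enumerate(batches):
    departing = batches[t - window_len] if t >= window_len else []
    state = slide(state, batch, departing, operator.add, operator.sub)
    print(t, sorted(state.items()))
# t=2 prints [('a', 5), ('b', 1)] -- the sum over batches 1 and 2 only.

countByWindow and countByValueAndWindow in the diff are this same pattern with operator.add/operator.sub (or +/- lambdas) as the reduce and inverse-reduce pair.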

python/pyspark/streaming/tests.py

Lines changed: 72 additions & 63 deletions
@@ -33,42 +33,64 @@
 
 
 class PySparkStreamingTestCase(unittest.TestCase):
+
+    timeout = 10  # seconds
+
     def setUp(self):
         class_name = self.__class__.__name__
         self.sc = SparkContext(appName=class_name)
+        self.sc.setCheckpointDir("/tmp")
         self.ssc = StreamingContext(self.sc, duration=Seconds(1))
 
     def tearDown(self):
-        # Do not call pyspark.streaming.context.StreamingContext.stop directly because
-        # we do not wait to shutdown py4j client.
         self.ssc.stop()
         self.sc.stop()
-        time.sleep(1)
 
     @classmethod
     def tearDownClass(cls):
         # Make sure tp shutdown the callback server
         SparkContext._gateway._shutdown_callback_server()
 
+    def _test_func(self, input, func, expected, numSlices=None, sort=False):
+        """
+        Start stream and return the result.
+        @param input: dataset for the test. This should be list of lists.
+        @param func: wrapped function. This function should return PythonDStream object.
+        @param expected: expected output for this testcase.
+        @param numSlices: the number of slices in the rdd in the dstream.
+        """
+        # Generate input stream with user-defined input.
+        input_stream = self.ssc._makeStream(input, numSlices)
+        # Apply test function to stream.
+        stream = func(input_stream)
+        result = stream.collect()
+        self.ssc.start()
 
-class TestBasicOperations(PySparkStreamingTestCase):
-    """
-    2 tests for each function for batach deserializer and unbatch deserilizer because
-    the deserializer is not changed dunamically after streaming process starts.
-    Default numInputPartitions is 2.
-    If the number of input element is over 3, that DStream use batach deserializer.
-    If not, that DStream use unbatch deserializer.
-
-    All tests input should have list of lists(3 lists are default). This list represents stream.
-    Every batch interval, the first object of list are chosen to make DStream.
-    e.g The first list in the list is input of the first batch.
-    Please see the BasicTestSuits in Scala which is close to this implementation.
-    """
-    def setUp(self):
-        PySparkStreamingTestCase.setUp(self)
-        self.timeout = 10  # seconds
-        self.numInputPartitions = 2
+        start_time = time.time()
+        # Loop until get the expected the number of the result from the stream.
+        while True:
+            current_time = time.time()
+            # Check time out.
+            if (current_time - start_time) > self.timeout:
+                break
+            # StreamingContext.awaitTermination is not used to wait because
+            # if py4j server is called every 50 milliseconds, it gets an error.
+            time.sleep(0.05)
+            # Check if the output is the same length of expected output.
+            if len(expected) == len(result):
+                break
+        if sort:
+            self._sort_result_based_on_key(result)
+            self._sort_result_based_on_key(expected)
+        self.assertEqual(expected, result)
 
+    def _sort_result_based_on_key(self, outputs):
+        """Sort the list based on first value."""
+        for output in outputs:
+            output.sort(key=lambda x: x[0])
+
+
+class TestBasicOperations(PySparkStreamingTestCase):
     def test_map(self):
         """Basic operation test for DStream.map."""
         input = [range(1, 5), range(5, 9), range(9, 13)]
@@ -239,54 +261,41 @@ def test_union(self):
                 break
         self.assertEqual(expected, result)
 
-    def _sort_result_based_on_key(self, outputs):
-        """Sort the list base onf first value."""
-        for output in outputs:
-            output.sort(key=lambda x: x[0])
 
-    def _test_func(self, input, func, expected, numSlices=None, sort=False):
-        """
-        Start stream and return the result.
-        @param input: dataset for the test. This should be list of lists.
-        @param func: wrapped function. This function should return PythonDStream object.
-        @param expected: expected output for this testcase.
-        @param numSlices: the number of slices in the rdd in the dstream.
-        """
-        # Generate input stream with user-defined input.
-        numSlices = numSlices or self.numInputPartitions
-        input_stream = self.ssc._makeStream(input, numSlices)
-        # Apply test function to stream.
-        stream = func(input_stream)
-        result = stream.collect()
-        self.ssc.start()
+class TestWindowFunctions(PySparkStreamingTestCase):
 
-        start_time = time.time()
-        # Loop until get the expected the number of the result from the stream.
-        while True:
-            current_time = time.time()
-            # Check time out.
-            if (current_time - start_time) > self.timeout:
-                break
-            # StreamingContext.awaitTermination is not used to wait because
-            # if py4j server is called every 50 milliseconds, it gets an error.
-            time.sleep(0.05)
-            # Check if the output is the same length of expected output.
-            if len(expected) == len(result):
-                break
-        if sort:
-            self._sort_result_based_on_key(result)
-            self._sort_result_based_on_key(expected)
-        self.assertEqual(expected, result)
+    timeout = 15
 
+    def test_count_by_window(self):
+        input = [range(1), range(2), range(3), range(4), range(5), range(6)]
 
-class TestStreamingContext(unittest.TestCase):
-    """
-    Should we have conf property in SparkContext?
-    @property
-    def conf(self):
-        return self._conf
+        def func(dstream):
+            return dstream.countByWindow(4, 1)
+
+        expected = [[1], [3], [6], [9], [12], [15], [11], [6]]
+        self._test_func(input, func, expected)
+
+    def test_count_by_window_large(self):
+        input = [range(1), range(2), range(3), range(4), range(5), range(6)]
 
-    """
+        def func(dstream):
+            return dstream.countByWindow(6, 1)
+
+        expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
+        self._test_func(input, func, expected)
+
+    def test_group_by_key_and_window(self):
+        input = [[('a', i)] for i in range(5)]
+
+        def func(dstream):
+            return dstream.groupByKeyAndWindow(4, 1).mapValues(list)
+
+        expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
+                    [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
+        self._test_func(input, func, expected)
+
+
+class TestStreamingContext(unittest.TestCase):
     def setUp(self):
         self.sc = SparkContext(master="local[2]", appName=self.__class__.__name__)
         self.batachDuration = Seconds(1)
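
A note on the groupByKeyAndWindow test above: the grouping is expressed through the same invertible-reduce machinery, with lambda a, b: a.extend(b) or a appending the values that enter the window (list.extend returns None, so "or a" yields the mutated list) and lambda a, b: a[len(b):] dropping as many of the oldest values as the departing batch contributed. A small stand-alone sketch of those two lambdas, with invented batches and window length, just to show the bookkeeping:

# Stand-alone check of the two lambdas driving the windowed grouping above.
# The batches and window length here are invented for the example.
append_new = lambda a, b: a.extend(b) or a   # extend() returns None, so "or a" yields a
drop_old = lambda a, b: a[len(b):]           # discard values contributed by the departing batch

batches = [[0], [1], [2], [3], [4]]          # one value per batch for a single key
window_len = 3                               # keep the last few batches
grouped = []
for t, batch in enumerate(batches):
    grouped = append_new(grouped, batch)                     # values entering the window
    if t >= window_len:
        grouped = drop_old(grouped, batches[t - window_len])  # values leaving the window
    print(t, grouped)
# prints [0], [0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4] -- the same grow-then-slide
# pattern as the expected output of test_group_by_key_and_window above.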
