 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.util import rddToFileName, RDDFunction, RDDFunction2
 from pyspark.rdd import portable_hash
-from pyspark.streaming.duration import Seconds
-
+from pyspark.streaming.duration import Duration, Seconds
+from pyspark.resultiterable import ResultIterable
 
 __all__ = ["DStream"]
 
@@ -299,13 +299,17 @@ def get_output(rdd, time):
         return result
 
     def transform(self, func):
-        return TransformedRDD(self, lambda a, t: func(a), True)
+        return TransformedDStream(self, lambda a, t: func(a), True)
 
     def transformWithTime(self, func):
-        return TransformedRDD(self, func, False)
+        return TransformedDStream(self, func, False)
 
     def transformWith(self, func, other, keepSerializer=False):
-        return Transformed2RDD(self, lambda a, b, t: func(a, b), other, keepSerializer)
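+        # Wrap the Python function in RDDFunction2 so the JVM-side
+        # PythonTransformed2DStream can call back into Python with both RDDs.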
+        jfunc = RDDFunction2(self.ctx, func, self._jrdd_deserializer)
+        dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
+                                                          other._jdstream.dstream(), jfunc)
+        jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer
+        return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer)
 
     def repartitions(self, numPartitions):
         return self.transform(lambda rdd: rdd.repartition(numPartitions))
@@ -336,28 +340,60 @@ def window(self, windowDuration, slideDuration=None):
             s = Seconds(slideDuration)
         return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
 
-    def reduceByWindow(self, reduceFunc, inReduceFunc, windowDuration, slideDuration):
-        pass
-
-    def countByWindow(self, window, slide):
-        pass
-
-    def countByValueAndWindow(self, window, slide, numPartitions=None):
-        pass
-
-    def groupByKeyAndWindow(self, window, slide, numPartitions=None):
-        pass
-
-    def reduceByKeyAndWindow(self, reduceFunc, inReduceFunc, window, slide, numPartitions=None):
-        pass
+    def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration):
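+        # Key every element by the constant 1 so reduceByKeyAndWindow can
+        # maintain the window incrementally (invReduceFunc backs out batches
+        # that slide out), then strip the dummy key from the result.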
+        keyed = self.map(lambda x: (1, x))
+        reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc,
+                                             windowDuration, slideDuration, 1)
+        return reduced.map(lambda (k, v): v)
+
+    def countByWindow(self, windowDuration, slideDuration):
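+        # Count by summing 1s over the window; operator.sub subtracts the
+        # counts of departing batches instead of recounting the whole window.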
+        return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub,
+                                                    windowDuration, slideDuration)
+
+    def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None):
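+        # Keep a running (value, count) pair per distinct value; filtering out
+        # zero counts and counting what remains yields the number of distinct
+        # values currently in the window.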
+        keyed = self.map(lambda x: (x, 1))
+        counted = keyed.reduceByKeyAndWindow(lambda a, b: a + b, lambda a, b: a - b,
+                                             windowDuration, slideDuration, numPartitions)
+        return counted.filter(lambda (k, v): v > 0).count()
+
+    def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
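+        # Collect values into lists so they can be combined incrementally:
+        # 'a.extend(b) or a' appends the new batch in place (extend returns
+        # None, so 'or a' yields the list), and 'a[len(b):]' drops the oldest
+        # elements when a batch leaves the window.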
+        ls = self.mapValues(lambda x: [x])
+        grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):],
+                                          windowDuration, slideDuration, numPartitions)
+        return grouped.mapValues(ResultIterable)
+
+    def reduceByKeyAndWindow(self, func, invFunc,
+                             windowDuration, slideDuration, numPartitions=None):
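+        # Incremental window reduction: each interval, fold the entering batch
+        # into the running per-key totals (reduceFunc) and apply the inverse
+        # function to the batch falling out of the window (invReduceFunc).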
+        reduced = self.reduceByKey(func)
+
+        def reduceFunc(a, t):
+            return a.reduceByKey(func, numPartitions)
+
+        def invReduceFunc(a, b, t):
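+            # 'b' holds the batch leaving the window; join it against the
+            # running totals and subtract, leaving keys with no departing
+            # values untouched.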
+            b = b.reduceByKey(func, numPartitions)
+            joined = a.leftOuterJoin(b, numPartitions)
+            return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
+
+        if not isinstance(windowDuration, Duration):
+            windowDuration = Seconds(windowDuration)
+        if not isinstance(slideDuration, Duration):
+            slideDuration = Seconds(slideDuration)
+        serializer = reduced._jrdd_deserializer
+        jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
+        jinvReduceFunc = RDDFunction2(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
+        dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
+                                                             jreduceFunc, jinvReduceFunc,
+                                                             windowDuration._jduration,
+                                                             slideDuration._jduration)
+        return DStream(dstream.asJavaDStream(), self._ssc, serializer)
 
     def updateStateByKey(self, updateFunc):
         # FIXME: convert updateFunc to java JFunction2
         jFunc = updateFunc
         return self._jdstream.updateStateByKey(jFunc)
 
 
-class TransformedRDD(DStream):
+class TransformedDStream(DStream):
     def __init__(self, prev, func, reuse=False):
         ssc = prev._ssc
         self._ssc = ssc
@@ -366,7 +402,8 @@ def __init__(self, prev, func, reuse=False):
         self.is_cached = False
         self.is_checkpointed = False
 
-        if isinstance(prev, TransformedRDD) and not prev.is_cached and not prev.is_checkpointed:
+        if (isinstance(prev, TransformedDStream) and
+                not prev.is_cached and not prev.is_checkpointed):
             prev_func = prev.func
             old_func = func
             func = lambda rdd, t: old_func(prev_func(rdd, t), t)
@@ -388,13 +425,3 @@ def _jdstream(self):
                                                jfunc, self.reuse).asJavaDStream()
         self._jdstream_val = jdstream
         return jdstream
-
-
-class Transformed2RDD(DStream):
-    def __init__(self, prev, func, other, keepSerializer=False):
-        ssc = prev._ssc
-        jfunc = RDDFunction2(ssc._sc, func, prev._jrdd_deserializer)
-        jdstream = ssc._jvm.PythonTransformed2DStream(prev._jdstream.dstream(),
-                                                      other._jdstream.dstream(), jfunc)
-        jrdd_serializer = prev._jrdd_deserializer if keepSerializer else ssc._sc.serializer
-        DStream.__init__(self, jdstream.asJavaDStream(), ssc, jrdd_serializer)