@@ -552,14 +552,18 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None
552552
def reduceFunc(t, a, b):
    """Fold the new batch ``b`` into the previously reduced value ``a``.

    ``t`` is accepted but not used. ``func``, ``numPartitions`` and
    ``filterFunc`` come from the enclosing scope.

    Returns ``b`` reduced by key when ``a`` is falsy; otherwise the union
    of ``a`` and the reduced ``b``, reduced by key once more. When
    ``filterFunc`` is set, the result is filtered with it before returning.
    """
    b = b.reduceByKey(func, numPartitions)
    if a:
        # Only look at ``a`` once we know it is usable: computing the
        # partition count up front would call getNumPartitions() on a
        # falsy ``a`` whenever numPartitions is not given.
        # Use the average of the two partition counts, or the count will
        # keep increasing. Floor division keeps it an integer under true
        # division; max() keeps it at least 1 (a zero count is presumably
        # not usable for the shuffle -- NOTE(review): confirm).
        partitions = numPartitions or max(
            1, (a.getNumPartitions() + b.getNumPartitions()) // 2)
        r = a.union(b).reduceByKey(func, partitions)
    else:
        r = b
    if filterFunc:
        r = r.filter(filterFunc)
    return r
559561
def invReduceFunc(t, a, b):
    """Remove the contribution of batch ``b`` from the reduced value ``a``.

    ``t`` is accepted but not used. ``func``, ``invFunc`` and
    ``numPartitions`` come from the enclosing scope.

    ``b`` is reduced by key and left-outer-joined onto ``a``. For each key,
    ``invFunc(v1, v2)`` is applied when the join produced a value from
    ``b``; otherwise the value from ``a`` is kept as is.
    """
    b = b.reduceByKey(func, numPartitions)
    # Use the average of the two partition counts, or the count will keep
    # increasing. Floor division keeps it an integer under true division;
    # max() keeps it at least 1 (a zero count is presumably not usable for
    # the shuffle -- NOTE(review): confirm).
    partitions = numPartitions or max(
        1, (a.getNumPartitions() + b.getNumPartitions()) // 2)
    joined = a.leftOuterJoin(b, partitions)
    # Index into the (v1, v2) pair instead of unpacking it in the lambda
    # signature, which is not valid syntax on Python 3.
    return joined.mapValues(
        lambda pair: invFunc(pair[0], pair[1]) if pair[1] is not None else pair[0])
564568
565569 jreduceFunc = RDDFunction (self .ctx , reduceFunc , reduced ._jrdd_deserializer )
@@ -587,7 +591,9 @@ def reduceFunc(t, a, b):
587591 if a is None :
588592 g = b .groupByKey (numPartitions ).map (lambda (k , vs ): (k , list (vs ), None ))
589593 else :
590- g = a .cogroup (b , numPartitions )
594+ # use the average of number of partitions, or it will keep increasing
595+ partitions = numPartitions or (a .getNumPartitions () + b .getNumPartitions ())/ 2
596+ g = a .cogroup (b , partitions )
591597 g = g .map (lambda (k , (va , vb )): (k , list (vb ), list (va )[0 ] if len (va ) else None ))
592598 state = g .mapPartitions (lambda x : updateFunc (x ))
593599 return state .filter (lambda (k , v ): v is not None )
0 commit comments