@@ -366,35 +366,37 @@ def takeSample(self, withReplacement, num, seed=None):
366366 [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
367367 """
368368
369- #TODO remove
370- logging .basicConfig (level = logging .INFO )
371369 numStDev = 10.0
372370 initialCount = self .count ()
373371
374372 if num < 0 :
375373 raise ValueError
376374
377- if initialCount == 0 :
375+ if initialCount == 0 or num == 0 :
378376 return list ()
379377
378+ rand = Random (seed )
380379 if (not withReplacement ) and num > initialCount :
381- raise ValueError
380+ # shuffle current RDD and return
381+ samples = self .collect ()
382+ fraction = float (num ) / initialCount
383+ num = initialCount
384+ else :
385+ maxSampleSize = sys .maxint - int (numStDev * sqrt (sys .maxint ))
386+ if num > maxSampleSize :
387+ raise ValueError
382388
383- maxSampleSize = sys .maxint - int (numStDev * sqrt (sys .maxint ))
384- if num > maxSampleSize :
385- raise ValueError
389+ fraction = self ._computeFractionForSampleSize (num , initialCount , withReplacement )
386390
387- fraction = self ._computeFractionForSampleSize (num , initialCount , withReplacement )
388-
389- samples = self .sample (withReplacement , fraction , seed ).collect ()
391+ samples = self .sample (withReplacement , fraction , seed ).collect ()
390392
391- # If the first sample didn't turn out large enough, keep trying to take samples;
392- # this shouldn't happen often because we use a big multiplier for their initial size.
393- # See: scala/spark/RDD.scala
394- rand = Random ( seed )
395- while len ( samples ) < num :
396- #TODO add log warning for when more than one iteration was run
397- samples = self .sample (withReplacement , fraction , rand . randint ( 0 , sys . maxint ) ).collect ()
393+ # If the first sample didn't turn out large enough, keep trying to take samples;
394+ # this shouldn't happen often because we use a big multiplier for their initial size.
395+ # See: scala/spark/RDD.scala
396+ while len ( samples ) < num :
397+ #TODO add log warning for when more than one iteration was run
398+ seed = rand . randint ( 0 , sys . maxint )
399+ samples = self .sample (withReplacement , fraction , seed ).collect ()
398400
399401 sampler = RDDSampler (withReplacement , fraction , rand .randint (0 , sys .maxint ))
400402 sampler .shuffle (samples )
0 commit comments