class RDDSamplerBase(object):
    """Common state and random-number helpers shared by the RDD samplers.

    Subclasses implement ``func(split, iterator)`` and are expected to call
    ``initRandomGenerator(split)`` before drawing any sample, because the
    generator is not created until then.
    """

    def __init__(self, withReplacement, seed=None):
        """Store the sampling mode and pick a seed.

        :param withReplacement: whether subclasses sample with replacement.
        :param seed: base seed; when None a random one is drawn from the
            module-level ``random`` generator.
        """
        # sys.maxsize exists on both Python 2.6+ and Python 3, unlike
        # sys.maxint which is Python 2 only.
        self._seed = seed if seed is not None else random.randint(0, sys.maxsize)
        self._withReplacement = withReplacement
        # Created lazily by initRandomGenerator().
        self._random = None

    def initRandomGenerator(self, split):
        """(Re)create the generator for ``split``, seeded with ``seed ^ split``.

        :param split: integer XOR-ed into the seed (presumably the partition
            index -- confirm against the caller).
        """
        self._random = random.Random(self._seed ^ split)

        # mixing because the initial seeds are close to each other
        for _ in range(10):
            self._random.randint(0, 1)

    def getUniformSample(self):
        """Return the next float in [0.0, 1.0) from the current generator.

        ``initRandomGenerator`` must have been called first.
        """
        return self._random.random()

    def getPoissonSample(self, mean):
        """Draw one Poisson-distributed count whose expectation is ``mean``.

        The count is the number of arrivals that fall inside the unit
        interval when the gaps between arrivals are exponential with rate
        ``mean`` (i.e. each gap has expectation ``1 / mean``).

        ``initRandomGenerator`` must have been called first.

        :param mean: expected value of the count.
        :return: a non-negative int; always 0 when ``mean`` is not positive.
        """
        # expovariate(0) raises ZeroDivisionError and a negative rate only
        # produces non-positive gaps, so the loop below would never end.
        if mean <= 0:
            return 0

        num_arrivals = 0
        cur_time = self._random.expovariate(mean)

        while cur_time < 1.0:
            cur_time += self._random.expovariate(mean)
            num_arrivals += 1

        return num_arrivals

    def func(self, split, iterator):
        """Sample ``iterator`` for ``split``; must be provided by subclasses."""
        raise NotImplementedError
9553
9654
9755class RDDSampler (RDDSamplerBase ):
@@ -101,31 +59,32 @@ def __init__(self, withReplacement, fraction, seed=None):
10159 self ._fraction = fraction
10260
10361 def func (self , split , iterator ):
62+ self .initRandomGenerator (split )
10463 if self ._withReplacement :
10564 for obj in iterator :
10665 # For large datasets, the expected number of occurrences of each element in
10766 # a sample with replacement is Poisson(frac). We use that to get a count for
10867 # each element.
109- count = self .getPoissonSample (split , mean = self ._fraction )
68+ count = self .getPoissonSample (self ._fraction )
11069 for _ in range (0 , count ):
11170 yield obj
11271 else :
11372 for obj in iterator :
114- if self .getUniformSample (split ) <= self ._fraction :
73+ if self .getUniformSample () <= self ._fraction :
11574 yield obj
11675
11776
class RDDRangeSampler(RDDSamplerBase):
    """Sampler that keeps an element when its uniform draw falls in a range.

    An element is kept when ``lowerBound <= u < upperBound`` for a fresh
    uniform draw ``u`` in [0, 1).
    """

    def __init__(self, lowerBound, upperBound, seed=None):
        """Record the half-open acceptance range ``[lowerBound, upperBound)``.

        :param lowerBound: inclusive lower end of the acceptance range.
        :param upperBound: exclusive upper end of the acceptance range.
        :param seed: base seed forwarded to ``RDDSamplerBase``.
        """
        # Range sampling is always without replacement.
        RDDSamplerBase.__init__(self, False, seed)
        self._lowerBound = lowerBound
        self._upperBound = upperBound

    def func(self, split, iterator):
        """Lazily yield the objects whose uniform draw lands in the range.

        :param split: integer passed to ``initRandomGenerator``.
        :param iterator: iterable of objects to sample from.
        """
        self.initRandomGenerator(split)
        for obj in iterator:
            if self._lowerBound <= self.getUniformSample() < self._upperBound:
                yield obj
13089
13190
@@ -136,15 +95,16 @@ def __init__(self, withReplacement, fractions, seed=None):
13695 self ._fractions = fractions
13796
13897 def func (self , split , iterator ):
98+ self .initRandomGenerator (split )
13999 if self ._withReplacement :
140100 for key , val in iterator :
141101 # For large datasets, the expected number of occurrences of each element in
142102 # a sample with replacement is Poisson(frac). We use that to get a count for
143103 # each element.
144- count = self .getPoissonSample (split , mean = self ._fractions [key ])
104+ count = self .getPoissonSample (self ._fractions [key ])
145105 for _ in range (0 , count ):
146106 yield key , val
147107 else :
148108 for key , val in iterator :
149- if self .getUniformSample (split ) <= self ._fractions [key ]:
109+ if self .getUniformSample () <= self ._fractions [key ]:
150110 yield key , val
0 commit comments