26
26
27
27
import math
28
28
import random
29
- from typing import Callable
29
+ from typing import Callable , Optional
30
30
31
31
import numpy as np
32
32
from scipy .stats import bootstrap
@@ -45,18 +45,27 @@ def mean_stderr(arr):
45
45
46
46
47
47
class _bootstrap_internal :
48
- def __init__ (self , metric : Callable , number_draws : int ):
49
- self .metric = metric
48
+ def __init__ (self , number_draws : int , metric : Optional [Callable ] = None ):
50
49
self .number_draws = number_draws
50
+ self .metric = metric
51
51
52
52
def __call__ (self , cur_experiment ):
53
53
# Creates number_draws samplings (with replacement) of the population, all drawn from a random generator seeded with the given seed
54
54
population , seed = cur_experiment
55
55
rnd = random .Random ()
56
56
rnd .seed (seed )
57
57
samplings = []
58
- for _ in range (self .number_draws ):
59
- samplings .append (self .metric (rnd .choices (population , k = len (population ))))
58
+ import multiprocessing as mp
59
+
60
+ with mp .Pool (mp .cpu_count ()) as pool :
61
+ samplings = pool .starmap (
62
+ self .metric ,
63
+ tqdm (
64
+ [(rnd .choices (population , k = len (population )),) for _ in range (self .number_draws )],
65
+ total = self .number_draws ,
66
+ desc = "Sampling bootstrap iterations" ,
67
+ ),
68
+ )
60
69
return samplings
61
70
62
71
@@ -65,28 +74,15 @@ def bootstrap_stderr(metric: Callable, population: list, number_experiments: int
65
74
by sampling said population for number_experiments and recomputing the metric on the
66
75
different samplings.
67
76
"""
68
- import multiprocessing as mp
69
-
70
- pool = mp .Pool (mp .cpu_count ())
71
-
72
77
res = []
73
78
number_draws = min (1000 , number_experiments )
74
- # We change the seed every 1000 re-samplings
75
- # and do the experiment 1000 re-samplings at a time
76
79
number_seeds = number_experiments // number_draws
77
80
78
- hlog (f"Bootstrapping { metric .__name__ } 's stderr." )
79
- for cur_bootstrap in tqdm (
80
- pool .imap (
81
- _bootstrap_internal (metric = metric , number_draws = number_draws ),
82
- ((population , seed ) for seed in range (number_seeds )),
83
- ),
84
- total = number_seeds ,
85
- ):
81
+ hlog (f"Bootstrapping { metric .__name__ } 's stderr with { number_seeds } seeds." )
82
+ for seed in range (number_seeds ):
86
83
# sample with replacement
87
- res .extend (cur_bootstrap )
84
+ res .extend (_bootstrap_internal ( metric = metric , number_draws = number_draws )(( population , seed )) )
88
85
89
- pool .close ()
90
86
return mean_stderr (res )
91
87
92
88
0 commit comments