Skip to content

Commit 357ad5f

Browse files
JoelNiklaus, clefourrier, NathanHB
authored
Speed up Bootstrapping Computation (#409)
* Added a fix for the heavy recomputation of sample-level metrics.
* Moved parallelization to where it is actually useful.

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>
1 parent 41dfe18 commit 357ad5f

File tree

1 file changed

+17
-21
lines changed

1 file changed

+17
-21
lines changed

src/lighteval/metrics/stderr.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
import math
2828
import random
29-
from typing import Callable
29+
from typing import Callable, Optional
3030

3131
import numpy as np
3232
from scipy.stats import bootstrap
@@ -45,18 +45,27 @@ def mean_stderr(arr):
4545

4646

4747
class _bootstrap_internal:
    """Callable that runs one seeded bootstrap experiment.

    Calling an instance with a ``(population, seed)`` pair draws
    ``number_draws`` resamplings (with replacement, each the size of the
    population) from a ``random.Random`` seeded with ``seed`` and evaluates
    ``metric`` on every resampling in a multiprocessing pool.

    Args:
        number_draws: Number of resamplings drawn per experiment.
        metric: Callable applied to each resampling. It defaults to ``None``
            but must be set before the instance is called. It is sent to the
            worker processes, so it has to be picklable.
    """

    def __init__(self, number_draws: int, metric: Optional[Callable] = None):
        self.number_draws = number_draws
        self.metric = metric

    def __call__(self, cur_experiment):
        """Return the metric value of each resampling, in draw order.

        Args:
            cur_experiment: A ``(population, seed)`` pair.

        Returns:
            A list of ``number_draws`` metric values.
        """
        import multiprocessing as mp

        population, seed = cur_experiment
        rnd = random.Random()
        rnd.seed(seed)

        # Draw every resampling in the parent process, sequentially from the
        # seeded generator, so the samples do not depend on worker scheduling.
        resamples = [rnd.choices(population, k=len(population)) for _ in range(self.number_draws)]

        workers = mp.cpu_count()
        # imap defaults to chunksize=1; batch the tasks so that the per-task
        # inter-process overhead stays small when number_draws is large.
        chunksize = max(1, self.number_draws // (4 * workers))

        with mp.Pool(workers) as pool:
            # Wrap the *result* stream with tqdm: imap yields results in input
            # order as they complete, so the bar tracks the metric computation
            # instead of filling up while the argument list is consumed.
            samplings = list(
                tqdm(
                    pool.imap(self.metric, resamples, chunksize=chunksize),
                    total=self.number_draws,
                    desc="Sampling bootstrap iterations",
                )
            )
        return samplings
6170

6271

@@ -65,28 +74,15 @@ def bootstrap_stderr(metric: Callable, population: list, number_experiments: int
6574
by sampling said population for number_experiments and recomputing the metric on the
6675
different samplings.
6776
"""
68-
import multiprocessing as mp
69-
70-
pool = mp.Pool(mp.cpu_count())
71-
7277
res = []
7378
number_draws = min(1000, number_experiments)
74-
# We change the seed every 1000 re-samplings
75-
# and do the experiment 1000 re-samplings at a time
7679
number_seeds = number_experiments // number_draws
7780

78-
hlog(f"Bootstrapping {metric.__name__}'s stderr.")
79-
for cur_bootstrap in tqdm(
80-
pool.imap(
81-
_bootstrap_internal(metric=metric, number_draws=number_draws),
82-
((population, seed) for seed in range(number_seeds)),
83-
),
84-
total=number_seeds,
85-
):
81+
hlog(f"Bootstrapping {metric.__name__}'s stderr with {number_seeds} seeds.")
82+
for seed in range(number_seeds):
8683
# sample w replacement
87-
res.extend(cur_bootstrap)
84+
res.extend(_bootstrap_internal(metric=metric, number_draws=number_draws)((population, seed)))
8885

89-
pool.close()
9086
return mean_stderr(res)
9187

9288

0 commit comments

Comments
 (0)