diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 37c8366838..dfc7eae382 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -11,6 +11,7 @@ ### Maintenance - Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508) +- Parallelization of population steppers (`DEMetropolis`) is now set via the `cores` argument. ([#3559](https://github.com/pymc-devs/pymc3/pull/3559)) ## PyMC3 3.7 (May 29 2019) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 8fb17df4f5..1c9b515b13 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -452,7 +452,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N if has_population_samplers: _log.info('Population sampling ({} chains)'.format(chains)) _print_step_hierarchy(step) - trace = _sample_population(**sample_args) + trace = _sample_population(**sample_args, parallelize=cores > 1) else: _log.info('Sequential sampling ({} chains in 1 job)'.format(chains)) _print_step_hierarchy(step) @@ -689,7 +689,7 @@ def __init__(self, steppers, parallelize): if parallelize: try: # configure a child process for each stepper - _log.info('Attempting to parallelize chains.') + _log.info('Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`.') import multiprocessing for c, stepper in enumerate(tqdm(steppers)): slave_end, master_end = multiprocessing.Pipe() @@ -714,7 +714,7 @@ def __init__(self, steppers, parallelize): _log.debug('Error was: ', exec_info=True) else: _log.info('Chains are not parallelized. You can enable this by passing ' - 'pm.sample(parallelize=True).') + '`pm.sample(cores=n)`, where n > 1.') return super().__init__() def __enter__(self): diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index 32547b1d88..d67c0ef1d2 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -915,12 +915,25 @@ def test_checks_population_size(self): trace = sample(draws=100, chains=4, step=step) pass + def test_nonparallelized_chains_are_random(self): + with Model() as model: + x = Normal("x", 0, 1) + for stepper in TestPopulationSamplers.steppers: + step = stepper() + trace = sample(chains=4, cores=1, draws=20, tune=0, step=DEMetropolis()) + samples = np.array(trace.get_values("x", combine=False))[:, 5] + + assert len(set(samples)) == 4, "Parallelized {} " "chains are identical.".format( + stepper + ) + pass + def test_parallelized_chains_are_random(self): with Model() as model: x = Normal("x", 0, 1) for stepper in TestPopulationSamplers.steppers: step = stepper() - trace = sample(chains=4, draws=20, tune=0, step=DEMetropolis()) + trace = sample(chains=4, cores=4, draws=20, tune=0, step=DEMetropolis()) samples = np.array(trace.get_values("x", combine=False))[:, 5] assert len(set(samples)) == 4, "Parallelized {} " "chains are identical.".format(