From 6d0629efe4e09aabee166d2aa23b205a1d671225 Mon Sep 17 00:00:00 2001 From: jan-matthis Date: Wed, 27 Sep 2023 21:38:52 +0200 Subject: [PATCH] Fix spurious seeding. --- sbi/inference/base.py | 9 ++++++++- tests/linearGaussian_snpe_test.py | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index cc3ca40a1..2c284c8d3 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -459,6 +459,7 @@ def simulate_for_sbi( num_simulations: int, num_workers: int = 1, simulation_batch_size: int = 1, + seed: Optional[int] = None, show_progress_bar: bool = True, ) -> Tuple[Tensor, Tensor]: r"""Returns ($\theta, x$) pairs obtained from sampling the proposal and simulating. @@ -481,6 +482,7 @@ def simulate_for_sbi( maps to data x at once. If None, we simulate all parameter sets at the same time. If >= 1, the simulator has to process data of shape (simulation_batch_size, parameter_dimension). + seed: Seed for reproducibility. show_progress_bar: Whether to show a progress bar for simulating. This will not affect whether there will be a progressbar while drawing samples from the proposal. @@ -491,7 +493,12 @@ def simulate_for_sbi( theta = proposal.sample((num_simulations,)) x = simulate_in_batches( - simulator, theta, simulation_batch_size, num_workers, show_progress_bar + simulator=simulator, + theta=theta, + sim_batch_size=simulation_batch_size, + num_workers=num_workers, + seed=seed, + show_progress_bars=show_progress_bar, ) return theta, x diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index 81a67d9ad..1fe2eb61e 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -59,8 +59,8 @@ def test_c2st_snpe_on_linearGaussian( """Test whether SNPE infers well a simple example with available ground truth.""" x_o = zeros(num_trials, num_dim) - num_samples = 1000 - num_simulations = 2600 + num_samples = 5000 + num_simulations = 5000 # likelihood_mean will be likelihood_shift+theta likelihood_shift = -1.0 * ones(num_dim)