Skip to content

Commit

Permalink
Fix spurious seeding.
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-matthis committed Oct 12, 2023
1 parent 6c4fa7a commit 6d0629e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 8 additions & 1 deletion sbi/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def simulate_for_sbi(
num_simulations: int,
num_workers: int = 1,
simulation_batch_size: int = 1,
seed: Optional[int] = None,
show_progress_bar: bool = True,
) -> Tuple[Tensor, Tensor]:
r"""Returns ($\theta, x$) pairs obtained from sampling the proposal and simulating.
Expand All @@ -481,6 +482,7 @@ def simulate_for_sbi(
maps to data x at once. If None, we simulate all parameter sets at the
same time. If >= 1, the simulator has to process data of shape
(simulation_batch_size, parameter_dimension).
seed: Seed for reproducibility.
show_progress_bar: Whether to show a progress bar for simulating. This will not
affect whether there will be a progressbar while drawing samples from the
proposal.
Expand All @@ -491,7 +493,12 @@ def simulate_for_sbi(
theta = proposal.sample((num_simulations,))

x = simulate_in_batches(
simulator, theta, simulation_batch_size, num_workers, show_progress_bar
simulator=simulator,
theta=theta,
sim_batch_size=simulation_batch_size,
num_workers=num_workers,
seed=seed,
show_progress_bars=show_progress_bar,
)

return theta, x
Expand Down
4 changes: 2 additions & 2 deletions tests/linearGaussian_snpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def test_c2st_snpe_on_linearGaussian(
"""Test whether SNPE infers well a simple example with available ground truth."""

x_o = zeros(num_trials, num_dim)
num_samples = 1000
num_simulations = 2600
num_samples = 5000
num_simulations = 5000

# likelihood_mean will be likelihood_shift+theta
likelihood_shift = -1.0 * ones(num_dim)
Expand Down

0 comments on commit 6d0629e

Please sign in to comment.