Skip to content

Commit

Permalink
947 vi coupling flows fail for num dim=3 (#950)
Browse files Browse the repository at this point in the history
* Fixing issue with affine coupling flows

* Adding cheap tests to test each default flow builder for minimal requirements. Automatically test all defaults.

* Format with black

* isort

* Updated test from other pull request (to also test all default flows, even if new ones are added)
  • Loading branch information
manuelgloeckler authored Feb 27, 2024
1 parent 116f99a commit 3aeb775
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
5 changes: 4 additions & 1 deletion sbi/samplers/vi/vi_pyro_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,10 @@ def init_affine_coupling(dim: int, device: str = "cpu", **kwargs):
nonlinearity = kwargs.pop("nonlinearity", nn.ReLU())
split_dim = kwargs.get("split_dim", dim // 2)
hidden_dims = kwargs.pop("hidden_dims", [5 * dim + 20, 5 * dim + 20])
arn = DenseNN(split_dim, hidden_dims, nonlinearity=nonlinearity).to(device)
params_dims = (dim - split_dim, dim - split_dim)
arn = DenseNN(split_dim, hidden_dims, params_dims, nonlinearity=nonlinearity).to(
device
)
return [split_dim, arn], {"log_scale_min_clip": -3.0}


Expand Down
37 changes: 34 additions & 3 deletions tests/vi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
from sbi.inference import SNLE, likelihood_estimator_based_potential
from sbi.inference.posteriors import VIPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.vi.vi_pyro_flows import get_default_flows, get_flow_builder
from sbi.simulators.linear_gaussian import true_posterior_linear_gaussian_mvn_prior
from sbi.utils import MultipleIndependent
from tests.test_utils import check_c2st

# Tests should be run for all default flows
FLOWS = get_default_flows()


class FakePotential(BasePotential):
def __call__(self, theta, **kwargs):
Expand Down Expand Up @@ -84,7 +88,7 @@ def allow_iid_x(self) -> bool:

@pytest.mark.slow
@pytest.mark.parametrize("num_dim", (1, 2))
@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
@pytest.mark.parametrize("q", FLOWS)
def test_c2st_vi_flows_on_Gaussian(num_dim: int, q: str):
"""Test VI on Gaussian, comparing to ground truth target via c2st.
Expand Down Expand Up @@ -189,7 +193,7 @@ def allow_iid_x(self) -> bool:
check_c2st(samples, target_samples, alg="slice_np")


@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
@pytest.mark.parametrize("q", FLOWS)
def test_deepcopy_support(q: str):
"""Tests if the variational does support deepcopy.
Expand Down Expand Up @@ -233,7 +237,7 @@ def test_deepcopy_support(q: str):
posterior_copy = deepcopy(posterior)


@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
@pytest.mark.parametrize("q", FLOWS)
def test_pickle_support(q: str):
"""Tests if the VIPosterior can be saved and loaded via pickle.
Expand Down Expand Up @@ -380,3 +384,30 @@ def simulator(theta):
sample_shape=(10,),
show_progress_bars=False,
)


@pytest.mark.parametrize("num_dim", (1, 2, 3, 4, 5, 10, 25, 33))
@pytest.mark.parametrize("q", FLOWS)
def test_vi_flow_builders(num_dim: int, q: str):
"""Test if the flow builder build the flows correctly, such that at least sampling and log_prob works."""

try:
q = get_flow_builder(q)(
(num_dim,), torch.distributions.transforms.identity_transform
)
except AssertionError:
# If the flow is not defined for the dimensionality, we pass the test
return

# Without sample_shape

sample = q.sample()
assert sample.shape == (num_dim,), "The sample shape is not as expected"
log_prob = q.log_prob(sample)
assert log_prob.shape == (), "The log_prob shape is not as expected"

# With sample_shape
sample_batch = q.sample((10,))
assert sample_batch.shape == (10, num_dim), "The sample shape is not as expected"
log_prob_batch = q.log_prob(sample_batch)
assert log_prob_batch.shape == (10,), "The log_prob shape is not as expected"

0 comments on commit 3aeb775

Please sign in to comment.