From 6f242f82974e0bec8fdf876585617515810a526b Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 22 Feb 2024 20:40:33 +0100 Subject: [PATCH 1/3] blub --- .github/workflows/ci.yaml | 4 + .github/workflows/release.yaml | 4 +- .pre-commit-config.yaml | 8 +- .readthedocs.yaml | 17 +++ examples/bivariate_gaussian_snasss.py | 18 +-- pyproject.toml | 4 + sbijax/__init__.py | 12 +- sbijax/{abc => _src}/__init__.py | 0 sbijax/{ => _src}/_sbi_base.py | 22 +-- sbijax/{ => _src}/_sne_base.py | 117 +++++++-------- sbijax/_src/abc/__init__.py | 0 sbijax/{ => _src}/abc/rejection_abc.py | 56 +++---- sbijax/{ => _src}/abc/smc_abc.py | 53 ++++--- sbijax/{ => _src}/generator.py | 31 +++- sbijax/_src/mcmc/__init__.py | 3 + sbijax/{ => _src}/mcmc/irmh.py | 29 ++-- sbijax/{ => _src}/mcmc/mala.py | 32 ++-- sbijax/{ => _src}/mcmc/nuts.py | 34 ++--- sbijax/{ => _src}/mcmc/rmh.py | 28 ++-- sbijax/{ => _src}/mcmc/sample.py | 14 +- sbijax/{ => _src}/mcmc/slice.py | 29 ++-- sbijax/_src/nn/__init__.py | 0 sbijax/_src/nn/make_snass_networks.py | 90 +++++++++++ sbijax/_src/nn/snass_net.py | 70 +++++++++ sbijax/_src/nn/snasss_net.py | 73 +++++++++ sbijax/{ => _src}/snass.py | 19 ++- sbijax/{ => _src}/snasss.py | 21 ++- sbijax/{ => _src}/snl.py | 167 +++++++++------------ sbijax/{ => _src}/snl_test.py | 2 +- sbijax/{ => _src}/snp.py | 90 +++++------ sbijax/{ => _src}/snp_test.py | 2 +- sbijax/_src/util/__init__.py | 0 sbijax/{nn => _src/util}/early_stopping.py | 17 ++- sbijax/mcmc/__init__.py | 3 - sbijax/nn/__init__.py | 3 + sbijax/nn/snass_net.py | 50 ------ sbijax/nn/snasss_net.py | 57 ------- 37 files changed, 634 insertions(+), 545 deletions(-) create mode 100644 .readthedocs.yaml rename sbijax/{abc => _src}/__init__.py (100%) rename sbijax/{ => _src}/_sbi_base.py (56%) rename sbijax/{ => _src}/_sne_base.py (62%) create mode 100644 sbijax/_src/abc/__init__.py rename sbijax/{ => _src}/abc/rejection_abc.py (61%) rename sbijax/{ => _src}/abc/smc_abc.py (88%) rename sbijax/{ => _src}/generator.py (74%) create mode 100644 sbijax/_src/mcmc/__init__.py rename sbijax/{ => _src}/mcmc/irmh.py (76%) rename sbijax/{ => _src}/mcmc/mala.py (75%) rename sbijax/{ => _src}/mcmc/nuts.py (77%) rename sbijax/{ => _src}/mcmc/rmh.py (76%) rename sbijax/{ => _src}/mcmc/sample.py (75%) rename sbijax/{ => _src}/mcmc/slice.py (70%) create mode 100644 sbijax/_src/nn/__init__.py create mode 100644 sbijax/_src/nn/make_snass_networks.py create mode 100644 sbijax/_src/nn/snass_net.py create mode 100644 sbijax/_src/nn/snasss_net.py rename sbijax/{ => _src}/snass.py (95%) rename sbijax/{ => _src}/snasss.py (95%) rename sbijax/{ => _src}/snl.py (67%) rename sbijax/{ => _src}/snl_test.py (99%) rename sbijax/{ => _src}/snp.py (83%) rename sbijax/{ => _src}/snp_test.py (98%) create mode 100644 sbijax/_src/util/__init__.py rename sbijax/{nn => _src/util}/early_stopping.py (75%) delete mode 100644 sbijax/mcmc/__init__.py delete mode 100644 sbijax/nn/snass_net.py delete mode 100644 sbijax/nn/snasss_net.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index bb061bd..d018580 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -77,3 +77,7 @@ jobs: - name: Run tests run: | hatch run test:test + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 7dd45f1..d1b137c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -12,9 
+12,9 @@ jobs: matrix: python-version: [3.11] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: Install pypa/build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3b0bd04..bfa8a58 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -60,8 +60,8 @@ repos: args: ["--ignore-missing-imports"] files: "(sbijax|examples)" -- repo: https://github.com/jorisroovers/gitlint - rev: v0.18.0 +- repo: https://github.com/pycqa/pydocstyle + rev: 6.1.1 hooks: - - id: gitlint - - id: gitlint-ci + - id: pydocstyle + additional_dependencies: ["toml"] diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..3332335 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +sphinx: + builder: html + configuration: docs/conf.py + fail_on_warning: false + +python: + install: + - method: pip + path: . + - requirements: docs/requirements.txt diff --git a/examples/bivariate_gaussian_snasss.py b/examples/bivariate_gaussian_snasss.py index 821bd49..3e335d4 100644 --- a/examples/bivariate_gaussian_snasss.py +++ b/examples/bivariate_gaussian_snasss.py @@ -21,7 +21,7 @@ from surjectors.util import unstack from sbijax import SNASSS -from sbijax.nn.snasss_net import SNASSSNet +from sbijax.nn import make_snasss_net W = jr.normal(jr.PRNGKey(0), (2, 10)) @@ -74,23 +74,17 @@ def _flow(method, **kwargs): return td -def make_critic(dim): - @hk.without_apply_rng - @hk.transform - def _net(method, **kwargs): - net = SNASSSNet([64, 64, dim], [64, 64, 1], [64, 64, 1]) - return net(method, **kwargs) - - return _net - - def run(): y_observed = jnp.array([[2.0, -2.0]]) @ W prior_simulator_fn, prior_logdensity_fn = prior_model_fns() fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn - estim = SNASSS(fns, make_model(2), make_critic(2)) + estim = SNASSS( + fns, + make_model(2), + make_snasss_net([64, 64, 2], [64, 64, 1], [64, 64, 1]), + ) optimizer = optax.adam(1e-3) data, params = None, {} diff --git a/pyproject.toml b/pyproject.toml index 53b9993..1c476d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,3 +96,7 @@ invalid-name,missing-module-docstring,R0801,E0633 [tool.bandit] skips = ["B101"] + +[tool.pydocstyle] +convention= 'google' +match = '^sbijax/.*/((?!_test).)*\.py' diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 589dc6a..6357a09 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -5,9 +5,9 @@ __version__ = "0.1.5" -from sbijax.abc.rejection_abc import RejectionABC -from sbijax.abc.smc_abc import SMCABC -from sbijax.snass import SNASS -from sbijax.snasss import SNASSS -from sbijax.snl import SNL -from sbijax.snp import SNP +from sbijax._src.abc.rejection_abc import RejectionABC +from sbijax._src.abc.smc_abc import SMCABC +from sbijax._src.snass import SNASS +from sbijax._src.snasss import SNASSS +from sbijax._src.snl import SNL +from sbijax._src.snp import SNP diff --git a/sbijax/abc/__init__.py b/sbijax/_src/__init__.py similarity index 100% rename from sbijax/abc/__init__.py rename to sbijax/_src/__init__.py diff --git a/sbijax/_sbi_base.py b/sbijax/_src/_sbi_base.py similarity index 56% rename from sbijax/_sbi_base.py rename to sbijax/_src/_sbi_base.py index 496f919..4a1b6a8 100644 --- a/sbijax/_sbi_base.py +++ b/sbijax/_src/_sbi_base.py @@ -6,23 +6,23 @@ # pylint: 
disable=too-many-instance-attributes,unused-argument, # pylint: disable=too-few-public-methods class SBI(abc.ABC): - """ - SBI base class - """ + """SBI base class.""" def __init__(self, model_fns): + """Construct an SBI object. + + Args: + model_fns: a tuple of the form `((prior_sampler_fn, prior_logdensity_fn), simulator_fn)` + """ self.prior_sampler_fn, self.prior_log_density_fn = model_fns[0] self.simulator_fn = model_fns[1] self._len_theta = len(self.prior_sampler_fn(seed=jr.PRNGKey(123))) @abc.abstractmethod - def sample_posterior(self, rng_key, **kwargs): - """ - Sample from the posterior distribution + def sample_posterior(self, rng_key: jr.KeyArray, **kwargs): + """Sample from the posterior distribution. - Parameters - ---------- - rng_key: jax.PRNGKey - a random key - kwargs: keyword arguments with sampler specific parameters + Args: + rng_key: a random key + kwargs: keyword arguments with sampler specific parameters """ diff --git a/sbijax/_sne_base.py b/sbijax/_src/_sne_base.py similarity index 62% rename from sbijax/_sne_base.py rename to sbijax/_src/_sne_base.py index b72af97..0fc377b 100644 --- a/sbijax/_sne_base.py +++ b/sbijax/_src/_sne_base.py @@ -4,19 +4,22 @@ from jax import numpy as jnp from jax import random as jr -from sbijax import generator -from sbijax._sbi_base import SBI -from sbijax.generator import named_dataset +from sbijax._src._sbi_base import SBI +from sbijax._src.generator import as_batch_iterators, named_dataset # pylint: disable=too-many-arguments,unused-argument # pylint: disable=too-many-function-args,arguments-differ class SNE(SBI, ABC): - """ - Sequential neural estimation - """ + """Sequential neural estimation base class.""" def __init__(self, model_fns, density_estimator): + """Construct an SNE object. + + Args: + model_fns: a tuple of the form `((prior_sampler_fn, prior_logdensity_fn), simulator_fn)` + density_estimator: a conditional density estimator, e.g., a masked autoregressive flow + """ super().__init__(model_fns) self.model = density_estimator self.n_total_simulations = 0 @@ -30,31 +33,19 @@ def simulate_data_and_possibly_append( n_simulations=1000, **kwargs, ): - """ - Simulate data from the posteriorand append it to an existing data set - (if provided) - - Parameters - ---------- - rng_key: jax.PRNGKey - a random key - params: pytree - a dictionary of neural network parameters - observable: jnp.ndarray - an observation - data: NamedTuple - existing data set - n_simulations: int - number of newly simulated data - kwargs: keyword arguments - dictionary of ey value pairs passed to `sample_posterior` - - Returns - ------- - NamedTuple: + """Simulate data from the prior or posterior and append. + + Args: + rng_key: a random key + params: a dictionary of neural network parameters + observable: an observation + data: existing data set + n_simulations: number of newly simulated data + kwargs: dictionary of key-value pairs passed to `sample_posterior` + + Returns: returns a NamedTuple of two axes, y and theta """ - observable = jnp.atleast_2d(observable) new_data, diagnostics = self.simulate_data( rng_key, @@ -78,31 +69,21 @@ def simulate_data( n_simulations=1000, **kwargs, ): + r"""Simulate data from the posterior or prior. + + Args: + rng_key: a random key + params: a dictionary of neural network parameters. If None, will + draw from prior. If parameters given, will draw from amortized + posterior using `observable` + observable: an observation.
Needs to be given if posterior draws + are desired + n_simulations: number of newly simulated data + kwargs: dictionary of key-value pairs passed to `sample_posterior` + + Returns: + a NamedTuple of two axes, y and theta """ - Simulate data from the posterior or prior and append it to an - existing data set (if provided) - - Parameters - ---------- - rng_key: jax.PRNGKey - a random key - params: Optional[pytree] - a dictionary of neural network parameters. If None, will draw from - prior. If parameters given, will draw from amortized posterior - using 'observable; - observable: Optional[jnp.ndarray] - an observation. Needs to be gfiven if posterior draws are desired - n_simulations: int - number of newly simulated data - kwargs: keyword arguments - dictionary of ey value pairs passed to `sample_posterior` - - Returns - ------- - NamedTuple: - returns a NamedTuple of two axis, y and theta - """ - sample_key, rng_key = jr.split(rng_key) if params is None or len(params) == 0: diagnostics = None @@ -140,21 +121,15 @@ def simulate_data( @staticmethod def stack_data(data, also_data): - """ - Stack two data sets. + """Stack two data sets. - Parameters - ---------- - data: NamedTuple - one data set - also_data: : NamedTuple + Args: + data: one data set + also_data: another data set - Returns - ------- - NamedTuple: + Returns: returns the stack of the two data sets """ - if data is None: return also_data if also_data is None: @@ -166,8 +141,18 @@ def as_iterators( self, rng_key, data, batch_size, percentage_data_as_validation_set ): - """Convert the data set to an iterable for training""" - return generator.as_batch_iterators( + """Convert the data set to an iterable for training. + + Args: + rng_key: a random key + data: a data set + batch_size: size of each batch + percentage_data_as_validation_set: fraction of the data used for validation + + Returns: + a batch iterator + """ + return as_batch_iterators( rng_key, data, batch_size, diff --git a/sbijax/_src/abc/__init__.py b/sbijax/_src/abc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sbijax/abc/rejection_abc.py b/sbijax/_src/abc/rejection_abc.py similarity index 61% rename from sbijax/abc/rejection_abc.py rename to sbijax/_src/abc/rejection_abc.py index f940022..d8f9506 100644 --- a/sbijax/abc/rejection_abc.py +++ b/sbijax/_src/abc/rejection_abc.py @@ -1,19 +1,33 @@ +from typing import Callable, Tuple + from jax import numpy as jnp from jax import random as jr -from sbijax._sbi_base import SBI +from sbijax._src._sbi_base import SBI # pylint: disable=too-many-instance-attributes,too-many-arguments -# pylint: disable=too-many-locals,too-few-public-methods +# pylint: disable=too-many-locals,too-few-public-methods, class RejectionABC(SBI): - """ - Sisson et al. - Handbook of approximate Bayesian computation + """Rejection approximate Bayesian computation. - Algorithm 4.1, "ABC Rejection Sampling Algorithm" + Implements algorithm 4.1 from [1]. + + References: + .. [1] Sisson, Scott A, et al. "Handbook of approximate Bayesian + computation". 2019 """ - def __init__(self, model_fns, summary_fn, kernel_fn): + def __init__( + self, model_fns: Tuple, summary_fn: Callable, kernel_fn: Callable + ): + """Construct a RejectionABC object.
+ + Args: + model_fns: a tuple of the form `((prior_sampler_fn, prior_logdensity_fn), simulator_fn)` + summary_fn: a summary statistics function + kernel_fn: a kernel function to compute similarities + """ super().__init__(model_fns) self.kernel_fn = kernel_fn self.summary_fn = summary_fn @@ -29,31 +43,21 @@ def sample_posterior( self, rng_key, observable, n_samples, n_simulations_per_theta, K, h, **kwargs, ): - """ - Sample from the approximate posterior + r"""Sample from the approximate posterior. - Parameters - ---------- - rng_key: jax.PRNGKey - a random key - observable: jnp.Array - observation to condition on - n_samples: int - number of samples to draw for each parameter - n_simulations_per_theta: int - number of simulations for each paramter sample - K: double - normalisation parameter - h: double - kernel scale + Args: + rng_key: a random key + observable: observation to condition on + n_samples: number of samples to draw for each parameter + n_simulations_per_theta: number of simulations for each parameter + sample + K: normalisation parameter + h: kernel scale - Returns - ------- - chex.Array + Returns: an array of samples from the posterior distribution of dimension (n_samples \times p) """ - observable = jnp.atleast_2d(observable) thetas = None diff --git a/sbijax/abc/smc_abc.py b/sbijax/_src/abc/smc_abc.py similarity index 88% rename from sbijax/abc/smc_abc.py rename to sbijax/_src/abc/smc_abc.py index c3a7e03..265bee4 100644 --- a/sbijax/abc/smc_abc.py +++ b/sbijax/_src/abc/smc_abc.py @@ -9,19 +9,29 @@ from jax import random as jr from jax import scipy as jsp -from sbijax._sbi_base import SBI +from sbijax._src._sbi_base import SBI # pylint: disable=arguments-differ,too-many-function-args,too-many-locals # pylint: disable=too-few-public-methods class SMCABC(SBI): - """ - Sisson et al. - Handbook of approximate Bayesian computation + """Sequential Monte Carlo approximate Bayesian computation. + + Implements algorithm 4.8 from [1]. - Algorithm 4.8, "Algorithm 4.8: ABC Sequential Monte Carlo Algorithm" + References: + .. [1] Sisson, Scott A, et al. "Handbook of approximate Bayesian + computation". 2019 """ def __init__(self, model_fns, summary_fn, distance_fn): + """Construct a SMCABC object. + + Args: + model_fns: a tuple of the form `((prior_sampler_fn, prior_logdensity_fn), simulator_fn)` + summary_fn: a summary statistics function + distance_fn: a distance function between summary statistics + """ super().__init__(model_fns) self.summary_fn = summary_fn self.distance_fn = distance_fn @@ -40,27 +50,20 @@ def sample_posterior( self, rng_key, observable, n_rounds, n_particles, n_simulations_per_theta, eps_step, ess_min, cov_scale=1.0, ): - """ - Sample from the approximate posterior - - Parameters - ---------- - n_rounds: int - max number of SMC rounds - n_particles: int - number of n_particles to draw for each parameter - n_simulations_per_theta: int - number of simulations for each paramter sample - eps_step: float - decay of initial epsilon per simulation round - ess_min: float - minimal effective sample size - cov_scale: float - scaling of the transition kernel covariance - - Returns - ------- - chex.Array + r"""Sample from the approximate posterior.
+ + Args: + rng_key: a random key + observable: the observation to condition on + n_rounds: max number of SMC rounds + n_particles: number of particles to draw for each parameter + n_simulations_per_theta: number of simulations for each parameter + sample + eps_step: decay of initial epsilon per simulation round + ess_min: minimal effective sample size + cov_scale: scaling of the transition kernel covariance + + Returns: an array of samples from the posterior distribution of dimension (n_samples \times p) """ diff --git a/sbijax/generator.py b/sbijax/_src/generator.py similarity index 74% rename from sbijax/generator.py rename to sbijax/_src/generator.py index ca2ab1b..e36ed29 100644 --- a/sbijax/generator.py +++ b/sbijax/_src/generator.py @@ -10,7 +10,10 @@ # pylint: disable=missing-class-docstring,too-few-public-methods class DataLoader: - def __init__(self, num_batches, idxs=None, get_batch=None, batches=None): + # noqa: D101 + def __init__( + self, num_batches, idxs=None, get_batch=None, batches=None + ): # noqa: D107 self.num_batches = num_batches self.idxs = idxs if idxs is not None: @@ -20,7 +23,7 @@ def __init__(self, num_batches, idxs=None, get_batch=None, batches=None): self.get_batch = get_batch self.batches = batches - def __call__(self, idx, idxs=None): + def __call__(self, idx, idxs=None): # noqa: D102 if self.batches is not None: return self.batches[idx] @@ -33,6 +36,19 @@ def as_batch_iterators( rng_key: chex.PRNGKey, data: named_dataset, batch_size, split, shuffle ): + """Create two data batch iterators from a data set. + + Args: + rng_key: a random key + data: a named tuple containing all data + batch_size: batch size + split: fraction of data to use for training data set. Rest is used + for validation data set. + shuffle: shuffle the data set or not + + Returns: + two iterators + """ n = data.y.shape[0] n_train = int(n * split) @@ -54,6 +70,17 @@ def as_batch_iterator( rng_key: chex.PRNGKey, data: named_dataset, batch_size, shuffle ): + """Create a data batch iterator from a data set. + + Args: + rng_key: a random key + data: a named tuple containing all data + batch_size: batch size + shuffle: shuffle the data set or not + + Returns: + an iterator + """ n = data.y.shape[0] if n < batch_size: num_batches = 1 diff --git a/sbijax/_src/mcmc/__init__.py b/sbijax/_src/mcmc/__init__.py new file mode 100644 index 0000000..d26d3dc --- /dev/null +++ b/sbijax/_src/mcmc/__init__.py @@ -0,0 +1,3 @@ +from sbijax._src.mcmc.nuts import sample_with_nuts +from sbijax._src.mcmc.sample import mcmc_diagnostics +from sbijax._src.mcmc.slice import sample_with_slice diff --git a/sbijax/mcmc/irmh.py b/sbijax/_src/mcmc/irmh.py similarity index 76% rename from sbijax/mcmc/irmh.py rename to sbijax/_src/mcmc/irmh.py index 05e0c6a..ca48942 100644 --- a/sbijax/mcmc/irmh.py +++ b/sbijax/_src/mcmc/irmh.py @@ -8,28 +8,17 @@ def sample_with_imh( rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs ): - """ - Sample from a distribution using the indepdendent Metropolis-Hastings - sampler. + r"""Draw sanokes using the independent Metropolis-Hastings sampler.
- Parameters - ---------- - rng_seq: hk.PRNGSequence - a hk.PRNGSequence - lp: Callable - the logdensity you wish to sample from - prior: Callable - a function that returns a prior sample - n_chains: int - number of chains to sample - n_samples: int - number of samples per chain - n_warmup: int - number of samples to discard + Args: + rng_key: a random key + lp: the logdensity you wish to sample from + prior: a function that returns a prior sample + n_chains: number of chains to sample + n_samples: number of samples per chain + n_warmup: number of samples to discard - Returns - ------- - jnp.ndarrau + Returns: a JAX array of dimension n_samples \times n_chains \times len_theta """ diff --git a/sbijax/mcmc/mala.py b/sbijax/_src/mcmc/mala.py similarity index 75% rename from sbijax/mcmc/mala.py rename to sbijax/_src/mcmc/mala.py index 7a84e67..0c2cedd 100644 --- a/sbijax/mcmc/mala.py +++ b/sbijax/_src/mcmc/mala.py @@ -8,27 +8,17 @@ def sample_with_mala( rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs ): - """ - Sample from a distribution using the MALA sampler. - - Parameters - ---------- - rng_seq: hk.PRNGSequence - a hk.PRNGSequence - lp: Callable - the logdensity you wish to sample from - prior: Callable - a function that returns a prior sample - n_chains: int - number of chains to sample - n_samples: int - number of samples per chain - n_warmup: int - number of samples to discard - - Returns - ------- - jnp.ndarrau + r"""Sample from a distribution using the MALA sampler. + + Args: + rng_key: a random key + lp: the logdensity you wish to sample from + prior: a function that returns a prior sample + n_chains: number of chains to sample + n_samples: number of samples per chain + n_warmup: number of samples to discard + + Returns: a JAX array of dimension n_samples \times n_chains \times len_theta """ diff --git a/sbijax/mcmc/nuts.py b/sbijax/_src/mcmc/nuts.py similarity index 77% rename from sbijax/mcmc/nuts.py rename to sbijax/_src/mcmc/nuts.py index 2781ec1..518897b 100644 --- a/sbijax/mcmc/nuts.py +++ b/sbijax/_src/mcmc/nuts.py @@ -8,27 +8,23 @@ def sample_with_nuts( rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs ): - """ - Sample from a distribution using the No-U-Turn sampler. + r"""Sample from a distribution using the No-U-Turn sampler. - Parameters - ---------- - rng_seq: hk.PRNGSequence - a hk.PRNGSequence - lp: Callable - the logdensity you wish to sample from - prior: Callable - a function that returns a prior sample - n_chains: int - number of chains to sample - n_samples: int - number of samples per chain - n_warmup: int - number of samples to discard + Args: + rng_key: a random key + lp: the logdensity you wish to sample from + prior: a function that returns a prior sample + n_chains: number of chains to sample + n_samples: number of samples per chain + n_warmup: number of samples to discard - Returns - ------- - jnp.ndarrau + Returns: a JAX array of dimension n_samples \times n_chains \times len_theta """ diff --git a/sbijax/mcmc/rmh.py b/sbijax/_src/mcmc/rmh.py similarity index 76% rename from sbijax/mcmc/rmh.py rename to sbijax/_src/mcmc/rmh.py index 51deb61..fb6eb47 100644 --- a/sbijax/mcmc/rmh.py +++ b/sbijax/_src/mcmc/rmh.py @@ -9,27 +9,17 @@ def sample_with_rmh( rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs ): - """ - Sample from a distribution using the Rosenbluth-Metropolis-Hastings sampler.
+ r"""Sample from a distribution using Rosenbluth-Metropolis-Hastings sampler. - Parameters - ---------- - rng_seq: hk.PRNGSequence - a hk.PRNGSequence - lp: Callable - the logdensity you wish to sample from - prior: Callable - a function that returns a prior sample - n_chains: int - number of chains to sample - n_samples: int - number of samples per chain - n_warmup: int - number of samples to discard + Args: + rng_key: a hk.PRNGSequence + lp: the logdensity you wish to sample from + prior: a function that returns a prior sample + n_chains: number of chains to sample + n_samples: number of samples per chain + n_warmup: number of samples to discard - Returns - ------- - jnp.ndarrau + Returns: a JAX array of dimension n_samples \times n_chains \times len_theta """ diff --git a/sbijax/mcmc/sample.py b/sbijax/_src/mcmc/sample.py similarity index 75% rename from sbijax/mcmc/sample.py rename to sbijax/_src/mcmc/sample.py index a72d076..f16e403 100644 --- a/sbijax/mcmc/sample.py +++ b/sbijax/_src/mcmc/sample.py @@ -4,23 +4,17 @@ def mcmc_diagnostics(samples): - """ - Computes MCMC diagnostics. + r"""Computes MCMC diagnostics. Compute effective sample sizes and R-hat for each parameter of a set of MCMC chains. - Parameters - ---------- - samples: jnp.ndarray - a JAX array of dimension n_samples \times n_chains \times n_dim + Args: + samples: a JAX array of dimension n_samples \times n_chains \times n_dim - Returns - ------- - tuple + Returns: a tuple of jnp.ndarrays with ess and rhat estimates. """ - n_theta = samples.shape[-1] esses = [0] * n_theta rhats = [0] * n_theta diff --git a/sbijax/mcmc/slice.py b/sbijax/_src/mcmc/slice.py similarity index 70% rename from sbijax/mcmc/slice.py rename to sbijax/_src/mcmc/slice.py index 82f914d..de64cd5 100644 --- a/sbijax/mcmc/slice.py +++ b/sbijax/_src/mcmc/slice.py @@ -17,30 +17,19 @@ def sample_with_slice( step_size=1, **kwargs, ): - """ - Sample from a distribution using the No-U-Turn sampler. + r"""Sample from a distribution using the No-U-Turn sampler. - Parameters - ---------- - rng_seq: hk.PRNGSequence - a hk.PRNGSequence - lp: Callable - the logdensity you wish to sample from - prior: Callable - a function that returns a prior sample - n_chains: int - number of chains to sample - n_samples: int - number of samples per chain - n_warmup: int - number of samples to discard + Args: + rng_seq: a hk.PRNGSequence + lp: the logdensity you wish to sample from + prior: a function that returns a prior sample + n_chains: number of chains to sample + n_samples: number of samples per chain + n_warmup: number of samples to discard - Returns - ------- - jnp.ndarrau + Returns: a JAX array of dimension n_samples \times n_chains \times len_theta """ - init_key, rng_key = jr.split(rng_key) initial_states = _slice_init(init_key, n_chains, prior) diff --git a/sbijax/_src/nn/__init__.py b/sbijax/_src/nn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sbijax/_src/nn/make_snass_networks.py b/sbijax/_src/nn/make_snass_networks.py new file mode 100644 index 0000000..81dda4a --- /dev/null +++ b/sbijax/_src/nn/make_snass_networks.py @@ -0,0 +1,90 @@ +from typing import Callable, List + +import haiku as hk +import jax + +from sbijax._src.nn.snass_net import SNASSNet +from sbijax._src.nn.snasss_net import SNASSSNet + + +def make_snass_net( + summary_net_dimensions: List[int], + critic_net_dimensions: List[int], + activation: Callable[[jax.Array], jax.Array] = jax.nn.relu, +): + """Create a critic network for SNASS. 
+ + Args: + summary_net_dimensions: a list of integers representing + the dimensionalities of the summary network. The _last_ dimension + determines the dimensionality of the summary statistic + critic_net_dimensions: a list of integers representing the + dimensionalities of the critic network. The _last_ dimension + needs to be 1. + activation: a jax activation function + + Returns: + a network that can be used within a SNASS posterior estimator + """ + + @hk.without_apply_rng + @hk.transform + def _net(method, **kwargs): + summary_net = hk.nets.MLP( + output_sizes=summary_net_dimensions, activation=activation + ) + critic_net = hk.nets.MLP( + output_sizes=critic_net_dimensions, activation=activation + ) + net = SNASSNet(summary_net=summary_net, critic_net=critic_net) + return net(method, **kwargs) + + return _net + + +def make_snasss_net( + summary_net_dimensions: List[int], + sec_summary_net_dimensions: List[int], + critic_net_dimensions: List[int], + activation: Callable[[jax.Array], jax.Array] = jax.nn.relu, +): + """Create a summary-critic network pair for SNASSS. + + Args: + summary_net_dimensions: a list of integers representing + the dimensionalities of the summary network. The _last_ dimension + determines the dimensionality of the summary statistic + sec_summary_net_dimensions: a list of integers representing + the dimensionalities of the summary network. The _last_ dimension + determines the dimensionality of the second summary statistic and + it should be smaller than the last dimension of the + first summary net. + critic_net_dimensions: a list of integers representing the + dimensionalities of the critic network. The _last_ dimension + needs to be 1. + activation: a jax activation function + + Returns: + a network that can be used within a SNASSS posterior estimator + """ + + @hk.without_apply_rng + @hk.transform + def _net(method, **kwargs): + summary_net = hk.nets.MLP( + output_sizes=summary_net_dimensions, activation=activation + ) + sec_summary_net = hk.nets.MLP( + output_sizes=sec_summary_net_dimensions, activation=activation + ) + critic_net = hk.nets.MLP( + output_sizes=critic_net_dimensions, activation=activation + ) + net = SNASSSNet( + summary_net=summary_net, + sec_summary_net=sec_summary_net, + critic_net=critic_net, + ) + return net(method, **kwargs) + + return _net diff --git a/sbijax/_src/nn/snass_net.py b/sbijax/_src/nn/snass_net.py new file mode 100644 index 0000000..25c72bd --- /dev/null +++ b/sbijax/_src/nn/snass_net.py @@ -0,0 +1,70 @@ +from typing import Callable, List + +import haiku as hk +import jax +from jax import numpy as jnp + + +# pylint: disable=missing-function-docstring,missing-class-docstring +# pydocstyle: disable=D102 +class SNASSNet(hk.Module): + """A network for SNASS.""" + + def __init__( + self, + summary_net_dimensions: List[int] = None, + critic_net_dimensions: List[int] = None, + summary_net: Callable = None, + critic_net: Callable = None, + ): + """Construct a SNASSNet. + + Can be used either by providing network dimensions or haiku modules. + + Args: + summary_net_dimensions: a list of integers representing + the dimensionalities of the summary network. The _last_ + dimension determines the dimensionality of the summary statistic + critic_net_dimensions: a list of integers representing the + dimensionalities of the critic network. The _last_ dimension + needs to be 1.
+ summary_net: a haiku MLP with trailing dimension being the + dimensionality of the summary statistic + critic_net: a haiku MLP with a trailing dimension of 1 + """ + super().__init__() + if summary_net_dimensions is not None: + assert critic_net_dimensions is not None + assert summary_net is None + assert critic_net is None + self._summary_net = hk.nets.MLP( + output_sizes=summary_net_dimensions, activation=jax.nn.relu + ) + self._critic_net = hk.nets.MLP( + output_sizes=critic_net_dimensions, activation=jax.nn.relu + ) + else: + assert summary_net is not None + assert critic_net is not None + self._summary_net = summary_net + self._critic_net = critic_net + + def __call__(self, method: str, **kwargs): + """Apply the network. + + Args: + method: the method to be called + kwargs: keyword arguments to be passed to the called method + """ + return getattr(self, "_" + method)(**kwargs) + + def _forward(self, y, theta): + s = self._summary(y) + c = self._critic(s, theta) + return s, c + + def _summary(self, y): + return self._summary_net(y) + + def _critic(self, y, theta): + return self._critic_net(jnp.concatenate([y, theta], axis=-1)) diff --git a/sbijax/_src/nn/snasss_net.py b/sbijax/_src/nn/snasss_net.py new file mode 100644 index 0000000..5bc3626 --- /dev/null +++ b/sbijax/_src/nn/snasss_net.py @@ -0,0 +1,73 @@ +from typing import Callable, List + +import haiku as hk +import jax +from jax import numpy as jnp + +from sbijax._src.nn.snass_net import SNASSNet + + +# pylint: disable=missing-function-docstring,missing-class-docstring +# pylint: disable=too-many-arguments +class SNASSSNet(SNASSNet): + """A network for SNASSS.""" + + def __init__( + self, + summary_net_dimensions: List[int] = None, + sec_summary_net_dimensions: List[int] = None, + critic_net_dimensions: List[int] = None, + summary_net: Callable = None, + sec_summary_net: Callable = None, + critic_net: Callable = None, + ): + """Construct a SNASSSNet. + + Can be used either by providing network dimensions or haiku modules. + + Args: + summary_net_dimensions: a list of integers representing + the dimensionalities of the summary network. The _last_ + dimension determines the dimensionality of the summary statistic + sec_summary_net_dimensions: a list of integers representing + the dimensionalities of the second summary network. The _last_ + dimension should be 1. + critic_net_dimensions: a list of integers representing the + dimensionalities of the critic network. The _last_ dimension + needs to be 1. + summary_net: a haiku MLP with trailing dimension being the + dimensionality of the summary statistic + sec_summary_net: a haiku MLP with trailing dimension of 1 + critic_net: a haiku MLP with a trailing dimension of 1 + """ + super().__init__( + summary_net_dimensions, + critic_net_dimensions, + summary_net, + critic_net, + ) + if sec_summary_net_dimensions is not None: + assert sec_summary_net is None + self._sec_summary_net = hk.nets.MLP( + output_sizes=sec_summary_net_dimensions, activation=jax.nn.relu + ) + else: + self._sec_summary_net = sec_summary_net + + def __call__(self, method: str, **kwargs): + """Apply the network.
+ + Args: + method: the method to be called + kwargs: keyword arguments to be passed to the called method + """ + return getattr(self, "_" + method)(**kwargs) + + def _forward(self, y, theta): + s = self._summary(y) + s2 = self._secondary_summary(s, theta) + c = self._critic(s2, y[:, [0]]) + return s, s2, c + + def _secondary_summary(self, y, theta): + return self._sec_summary_net(jnp.concatenate([y, theta], axis=-1)) diff --git a/sbijax/snass.py b/sbijax/_src/snass.py similarity index 95% rename from sbijax/snass.py rename to sbijax/_src/snass.py index ac37c9f..9968c4f 100644 --- a/sbijax/snass.py +++ b/sbijax/_src/snass.py @@ -7,9 +7,9 @@ from jax import numpy as jnp from jax import random as jr -from sbijax.generator import DataLoader -from sbijax.nn.early_stopping import EarlyStopping -from sbijax.snl import SNL +from sbijax._src.generator import DataLoader +from sbijax._src.snl import SNL +from sbijax._src.util.early_stopping import EarlyStopping def _jsd_summary_loss(params, rng, apply_fn, **batch): @@ -39,6 +39,13 @@ class SNASS(SNL): """ def __init__(self, model_fns, density_estimator, snass_net): + """Construct a SNASS object. + + Args: + model_fns: a tuple of the form `((prior_sampler_fn, prior_logdensity_fn), simulator_fn)` + density_estimator: a conditional density estimator, e.g., a masked autoregressive flow + snass_net: a summary-critic network, e.g., built with `make_snass_net` + """ super().__init__(model_fns, density_estimator) self.sc_net = snass_net @@ -64,7 +71,7 @@ def fit( n_iter: maximal number of training iterations per round batch_size: batch size used for training the model percentage_data_as_validation_set: percentage of the simulated data - that is used for valitation and early stopping + that is used for validation and early stopping n_early_stopping_patience: number of iterations of no improvement of training the flow before stopping optimisation kwargs: keyword arguments with sampler specific parameters. For @@ -77,7 +84,6 @@ def fit( Returns: tuple of parameters and a tuple of the training information """ - itr_key, rng_key = jr.split(rng_key) train_iter, val_iter = self.as_iterators( itr_key, data, batch_size, percentage_data_as_validation_set @@ -209,7 +215,7 @@ def sample_posterior( n_warmup=1_000, **kwargs, ): - """Sample from the approximate posterior. + r"""Sample from the approximate posterior. Args: rng_key: a random key @@ -229,7 +235,7 @@ def sample_posterior( an array of samples from the posterior distribution of dimension (n_samples \times p) """ - observable = jnp.atleast_2d(observable) summary = self.sc_net.apply( params["s_params"], method="summary", y=observable diff --git a/sbijax/snasss.py b/sbijax/_src/snasss.py similarity index 95% rename from sbijax/snasss.py rename to sbijax/_src/snasss.py index 82242fd..96f0043 100644 --- a/sbijax/snasss.py +++ b/sbijax/_src/snasss.py @@ -7,9 +7,9 @@ from jax import numpy as jnp from jax import random as jr -from sbijax.generator import DataLoader -from sbijax.nn.early_stopping import EarlyStopping -from sbijax.snl import SNL +from sbijax._src.generator import DataLoader +from sbijax._src.snl import SNL +from sbijax._src.util.early_stopping import EarlyStopping def _sample_unit_sphere(rng_key, n, dim): @@ -63,6 +63,13 @@ class SNASSS(SNL): """ def __init__(self, model_fns, density_estimator, summary_net): + """Construct a SNASSS object. + + Args: + model_fns: a tuple of the form `((prior_sampler_fn, prior_logdensity_fn), simulator_fn)` + density_estimator: a conditional density estimator, e.g., a masked autoregressive flow + summary_net: a summary-critic network, e.g., built with `make_snasss_net` + """ super().__init__(model_fns, density_estimator) self.sc_net = summary_net @@ -81,14 +88,14 @@ def fit( """Fit a SNASSS model.
Args: - rng_seq: a hk.PRNGSequence + rng_key: a random key data: data set obtained from calling `simulate_data_and_possibly_append` optimizer: an optax optimizer object n_iter: maximal number of training iterations per round batch_size: batch size used for training the model percentage_data_as_validation_set: percentage of the simulated data - that is used for valitation and early stopping + that is used for validation and early stopping n_early_stopping_patience: number of iterations of no improvement of training the flow before stopping optimisation kwargs: keyword arguments with sampler specific parameters. For slice sampling the following arguments are possible: - sampler: either 'nuts', 'slice' or None (defaults to nuts) - n_thin: number of thinning steps - n_doubling: number of doubling steps of the interval - step_size: step size of the initial interval @@ -101,7 +108,6 @@ def fit( Returns: tuple of parameters and a tuple of the training information """ - itr_key, rng_key = jr.split(rng_key) train_iter, val_iter = self.as_iterators( itr_key, data, batch_size, percentage_data_as_validation_set @@ -233,7 +239,7 @@ def sample_posterior( n_warmup=1_000, **kwargs, ): - """Sample from the approximate posterior. + r"""Sample from the approximate posterior. Args: rng_key: a random key @@ -253,7 +259,6 @@ def sample_posterior( an array of samples from the posterior distribution of dimension (n_samples \times p) """ - observable = jnp.atleast_2d(observable) summary = self.sc_net.apply( params["s_params"], method="summary", y=observable diff --git a/sbijax/snl.py b/sbijax/_src/snl.py similarity index 67% rename from sbijax/snl.py rename to sbijax/_src/snl.py index 97b20fb..92480bf 100644 --- a/sbijax/snl.py +++ b/sbijax/_src/snl.py @@ -8,12 +8,16 @@ from jax import numpy as jnp from jax import random as jr -from sbijax._sne_base import SNE -from sbijax.mcmc import mcmc_diagnostics, sample_with_nuts, sample_with_slice -from sbijax.mcmc.irmh import sample_with_imh -from sbijax.mcmc.mala import sample_with_mala -from sbijax.mcmc.rmh import sample_with_rmh -from sbijax.nn.early_stopping import EarlyStopping +from sbijax._src._sne_base import SNE +from sbijax._src.mcmc import ( + mcmc_diagnostics, + sample_with_nuts, + sample_with_slice, +) +from sbijax._src.mcmc.irmh import sample_with_imh +from sbijax._src.mcmc.mala import sample_with_mala +from sbijax._src.mcmc.rmh import sample_with_rmh +from sbijax._src.util.early_stopping import EarlyStopping # pylint: disable=too-many-arguments,unused-argument @@ -35,41 +39,30 @@ def fit( n_early_stopping_patience=10, **kwargs, ): - """ - Fit a SNL model - - Parameters - ---------- - rng_seq: hk.PRNGSequence - a hk.PRNGSequence - data: NamedTuple - data set obtained from calling `simulate_data_and_possibly_append` - optimizer: optax.Optimizer - an optax optimizer object - n_iter: - maximal number of training iterations per round - batch_size: int - batch size used for training the model - percentage_data_as_validation_set: - percentage of the simulated data that is used for valitation and - early stopping - n_early_stopping_patience: int - number of iterations of no improvement of training the flow - before stopping optimisation - kwargs: keyword arguments with sampler specific parameters. For slice - sampling the following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps - - n_doubling: number of doubling steps of the interval - - step_size: step size of the initial interval - - Returns - ------- - Tuple[pytree, Tuple] + """Fit a SNL model.
+ + Args: + rng_key: a random key + data: data set obtained from calling + `simulate_data_and_possibly_append` + optimizer: an optax optimizer object + n_iter: maximal number of training iterations per round + batch_size: batch size used for training the model + percentage_data_as_validation_set: percentage of the simulated data + that is used for validation and early stopping + n_early_stopping_patience: number of iterations of no improvement of + training the flow before stopping optimisation + kwargs: keyword arguments with sampler specific parameters. + For slice sampling the following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps + - n_doubling: number of doubling steps of the interval + - step_size: step size of the initial interval + + Returns: returns a tuple of parameters and a tuple of the training information """ - itr_key, rng_key = jr.split(rng_key) train_iter, val_iter = self.as_iterators( itr_key, data, batch_size, percentage_data_as_validation_set @@ -174,39 +167,26 @@ def simulate_data_and_possibly_append( n_warmup=1_000, **kwargs, ): - """ - Simulate data from the posterior and append it to an existing data set - (if provided) - - Parameters - ---------- - rng_key: jax.PRNGKey - a random key - params: pytree - a dictionary of neural network parameters - observable: jnp.ndarray - an observation - data: NamedTuple - existing data set - n_simulations: int - number of newly simulated data - n_chains: int - number of MCMC chains - n_samples: int - number of sa les to draw in total - n_warmup: int - number of draws to discared - kwargs: keyword arguments - dictionary of ey value pairs passed to `sample_posterior`. - The following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps (int) - - n_doubling: number of doubling steps of the interval (int) - - step_size: step size of the initial interval (float) - - Returns - ------- - NamedTuple: + """Simulate data from the prior or posterior. + + Args: + rng_key: a random key + params: a dictionary of neural network parameters + observable: an observation + data: existing data set + n_simulations: number of newly simulated data + n_chains: number of MCMC chains + n_samples: number of samples to draw in total + n_warmup: number of draws to discard + kwargs: keyword arguments + dictionary of key-value pairs passed to `sample_posterior`. + The following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps (int) + - n_doubling: number of doubling steps of the interval (int) + - step_size: step size of the initial interval (float) + + Returns: returns a NamedTuple of two axes, y and theta """ return super().simulate_data_and_possibly_append( @@ -232,37 +212,26 @@ def sample_posterior( n_warmup=1_000, **kwargs, ): - """ - Sample from the approximate posterior - - Parameters - ---------- - rng_key: jax.PRNGKey - a random key - params: pytree - a pytree of parameter for the model - observable: jnp.Array - observation to condition on - n_chains: int - number of MCMC chains - n_samples: int - number of samples per chain - n_warmup: int - number of samples to discard - kwargs: keyword arguments with sampler specific parameters.
For slice - sampling the following arguments are possible: - - sampler: either 'nuts', 'slice' or None (defaults to nuts) - - n_thin: number of thinning steps - - n_doubling: number of doubling steps of the interval - - step_size: step size of the initial interval - - Returns - ------- - chex.Array + r"""Sample from the approximate posterior. + + Args: + rng_key: a random key + params: a pytree of parameters for the model + observable: observation to condition on + n_chains: number of MCMC chains + n_samples: number of samples per chain + n_warmup: number of samples to discard + kwargs: keyword arguments with sampler specific parameters. For + slice sampling the following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps + - n_doubling: number of doubling steps of the interval + - step_size: step size of the initial interval + + Returns: an array of samples from the posterior distribution of dimension (n_samples \times p) """ - observable = jnp.atleast_2d(observable) return self._sample_posterior( rng_key, diff --git a/sbijax/snl_test.py b/sbijax/_src/snl_test.py similarity index 99% rename from sbijax/snl_test.py rename to sbijax/_src/snl_test.py index 4789338..ed2d703 100644 --- a/sbijax/snl_test.py +++ b/sbijax/_src/snl_test.py @@ -9,7 +9,7 @@ from surjectors.nn import make_mlp from surjectors.util import make_alternating_binary_mask -from sbijax import SNL +from sbijax._src.snl import SNL def prior_model_fns(): diff --git a/sbijax/snp.py b/sbijax/_src/snp.py similarity index 83% rename from sbijax/snp.py rename to sbijax/_src/snp.py index ae2170e..c822c12 100644 --- a/sbijax/snp.py +++ b/sbijax/_src/snp.py @@ -8,19 +8,25 @@ from jax import random as jr from jax import scipy as jsp -from sbijax._sne_base import SNE -from sbijax.nn.early_stopping import EarlyStopping +from sbijax._src._sne_base import SNE +from sbijax._src.util.early_stopping import EarlyStopping # pylint: disable=too-many-arguments,unused-argument class SNP(SNE): - """ - Sequential neural posterior estimation + """Sequential neural posterior estimation. - From the Greenberg paper + References: + .. [1] Greenberg, David, et al. "Automatic posterior transformation + for likelihood-free inference". 2019 """ def __init__(self, model_fns, density_estimator): + """Construct an SNP object. + + Args: + model_fns: a tuple of the form `((prior_sampler_fn, prior_logdensity_fn), simulator_fn)` + density_estimator: a conditional density estimator, e.g., a masked autoregressive flow + """ super().__init__(model_fns, density_estimator) self.n_round = 0 @@ -37,37 +43,24 @@ def fit( n_atoms=10, **kwargs, ): + """Fit an SNPE model.
+ + Args: + rng_key: a random key + data: data set obtained from calling + `simulate_data_and_possibly_append` + optimizer: an optax optimizer object + n_iter: maximal number of training iterations per round + batch_size: batch size used for training the model + percentage_data_as_validation_set: percentage of the simulated + data that is used for validation and early stopping + n_early_stopping_patience: number of iterations of no improvement + of training the flow before stopping optimisation + n_atoms: number of atoms to approximate the proposal posterior + + Returns: + a tuple of parameters and a tuple of the training information """ - Fit an SNPE model - - Parameters - ---------- - rng_seq: hk.PRNGSequence - a hk.PRNGSequence - data: NamedTuple - data set obtained from calling `simulate_data_and_possibly_append` - optimizer: optax.Optimizer - an optax optimizer object - n_iter: - maximal number of training iterations per round - batch_size: int - batch size used for training the model - percentage_data_as_validation_set: - percentage of the simulated data that is used for valitation and - early stopping - n_early_stopping_patience: int - number of iterations of no improvement of training the flow - before stopping optimisation - n_atoms : int - number of atoms to approximate the proposal posterior - - Returns - ------- - Tuple[pytree, Tuple] - returns a tuple of parameters and a tuple of the training - information - """ - itr_key, rng_key = jr.split(rng_key) train_iter, val_iter = self.as_iterators( itr_key, data, batch_size, percentage_data_as_validation_set @@ -239,27 +232,18 @@ def body_fn(i, rng_key): def sample_posterior( self, rng_key, params, observable, *, n_samples=4_000, **kwargs ): - """ - Sample from the approximate posterior - - Parameters - ---------- - rng_key: jax.PRNGKey - a random key - params: pytree - a pytree of parameter for the model - observable: jnp.Array - observation to condition on - n_samples: int - number of samples to draw - - Returns - ------- - chex.Array + r"""Sample from the approximate posterior.
+ + Args: + rng_key: a random key + params: a pytree of parameters for the model + observable: observation to condition on + n_samples: number of samples to draw + + Returns: an array of samples from the posterior distribution of dimension (n_samples \times p) """ - observable = jnp.atleast_2d(observable) thetas = None diff --git a/sbijax/snp_test.py b/sbijax/_src/snp_test.py similarity index 98% rename from sbijax/snp_test.py rename to sbijax/_src/snp_test.py index c6e4837..ace28c0 100644 --- a/sbijax/snp_test.py +++ b/sbijax/_src/snp_test.py @@ -7,7 +7,7 @@ from surjectors.nn import make_mlp from surjectors.util import make_alternating_binary_mask -from sbijax import SNP +from sbijax._src.snp import SNP def prior_model_fns(): diff --git a/sbijax/_src/util/__init__.py b/sbijax/_src/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sbijax/nn/early_stopping.py b/sbijax/_src/util/early_stopping.py similarity index 75% rename from sbijax/nn/early_stopping.py rename to sbijax/_src/util/early_stopping.py index 57f8f3c..e6b4a29 100644 --- a/sbijax/nn/early_stopping.py +++ b/sbijax/_src/util/early_stopping.py @@ -5,9 +5,7 @@ # pylint: disable=missing-function-docstring @dataclasses.dataclass class EarlyStopping: - """ - Early stopping of neural network training - """ + """Early stopping of neural network training.""" min_delta: float = 0 patience: int = 0 @@ -16,12 +14,25 @@ class EarlyStopping: should_stop: bool = False def reset(self): + """Reset the object. + + Returns: + self + """ self.best_metric = float("inf") self.patience_count = 0 self.should_stop = False return self def update(self, metric): + """Update the stopping criterion. + + Args: + metric: the tracked metric as float + + Returns: + tuple + """ if ( math.isinf(self.best_metric) or self.best_metric - metric > self.min_delta diff --git a/sbijax/mcmc/__init__.py b/sbijax/mcmc/__init__.py deleted file mode 100644 index b83d7e2..0000000 --- a/sbijax/mcmc/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from sbijax.mcmc.nuts import sample_with_nuts -from sbijax.mcmc.sample import mcmc_diagnostics -from sbijax.mcmc.slice import sample_with_slice diff --git a/sbijax/nn/__init__.py b/sbijax/nn/__init__.py index e69de29..1ca8c87 100644 --- a/sbijax/nn/__init__.py +++ b/sbijax/nn/__init__.py @@ -0,0 +1,3 @@ +"""Neural network module.""" + +from sbijax._src.nn.make_snass_networks import make_snass_net, make_snasss_net diff --git a/sbijax/nn/snass_net.py b/sbijax/nn/snass_net.py deleted file mode 100644 index 47b26a9..0000000 --- a/sbijax/nn/snass_net.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Callable, List - -import haiku as hk -import jax -from jax import numpy as jnp - - -# pylint: disable=missing-function-docstring,missing-class-docstring -class SNASSNet(hk.Module): - """ - A network for SNASS - """ - - def __init__( - self, - summary_net_dimensions: List[int] = None, - critic_net_dimensions: List[int] = None, - summary_net: Callable = None, - critic_net: Callable = None, - ): - super().__init__() - if summary_net_dimensions is not None: - assert critic_net_dimensions is not None - assert summary_net is None - assert critic_net is None - self._summary = hk.nets.MLP( - output_sizes=summary_net_dimensions, activation=jax.nn.relu - ) - self._critic = hk.nets.MLP( - output_sizes=critic_net_dimensions, activation=jax.nn.relu - ) - else: - assert summary_net is not None - assert critic_net is not None - self._summary = summary_net - self._critic = critic_net - - def __call__(self, method, **kwargs): - return
getattr(self, method)(**kwargs) - - def forward(self, y, theta): - s = self.summary(y) - c = self.critic(s, theta) - return s, c - - def summary(self, y): - return self._summary(y) - - def critic(self, y, theta): - return self._critic(jnp.concatenate([y, theta], axis=-1)) diff --git a/sbijax/nn/snasss_net.py b/sbijax/nn/snasss_net.py deleted file mode 100644 index d283123..0000000 --- a/sbijax/nn/snasss_net.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Callable, List - -import haiku as hk -import jax -from jax import numpy as jnp - - -# pylint: disable=missing-function-docstring,missing-class-docstring, -# pylint: too-many-arguments -class SNASSSNet(hk.Module): - def __init__( - self, - summary_net_dimensions: List[int] = None, - sec_summary_net_dimensions: List[int] = None, - critic_net_dimensions: List[int] = None, - summary_net: Callable = None, - sec_summary_net: Callable = None, - critic_net: Callable = None, - ): - super().__init__() - if summary_net_dimensions is not None: - assert critic_net_dimensions is not None - assert summary_net is None - assert critic_net is None - self._summary = hk.nets.MLP( - output_sizes=summary_net_dimensions, activation=jax.nn.relu - ) - self._secondary_summary = hk.nets.MLP( - output_sizes=sec_summary_net_dimensions, activation=jax.nn.relu - ) - self._critic = hk.nets.MLP( - output_sizes=critic_net_dimensions, activation=jax.nn.relu - ) - else: - assert summary_net is not None - assert critic_net is not None - self._summary = summary_net - self._secondary_summary = sec_summary_net - self._critic = critic_net - - def __call__(self, method, **kwargs): - return getattr(self, method)(**kwargs) - - def forward(self, y, theta): - s = self.summary(y) - s2 = self.secondary_summary(s, theta) - c = self.critic(s2, y[:, [0]]) - return s, s2, c - - def summary(self, y): - return self._summary(y) - - def secondary_summary(self, y, theta): - return self._secondary_summary(jnp.concatenate([y, theta], axis=-1)) - - def critic(self, y, theta): - return self._critic(jnp.concatenate([y, theta], axis=-1)) From f8b5f7f5abb7cafdb44d1e9dc1a4884270d29bf9 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 22 Feb 2024 21:00:21 +0100 Subject: [PATCH 2/3] blub --- .github/workflows/ci.yaml | 2 +- sbijax/_src/mcmc/irmh.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d018580..2d66cae 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -73,7 +73,7 @@ jobs: pip install hatch - name: Build package run: | - pip install jaxlib jax + pip install jaxlib==0.4.24 jax==0.4.24 - name: Run tests run: | hatch run test:test diff --git a/sbijax/_src/mcmc/irmh.py b/sbijax/_src/mcmc/irmh.py index ca48942..deaa559 100644 --- a/sbijax/_src/mcmc/irmh.py +++ b/sbijax/_src/mcmc/irmh.py @@ -8,7 +8,7 @@ def sample_with_imh( rng_key, lp, prior, *, n_chains=4, n_samples=2_000, n_warmup=1_000, **kwargs ): - r"""Draw sanokes using the independent Metropolis-Hastings sampler. + r"""Draw samples using the independent Metropolis-Hastings sampler.
Args: rng_key: a random key lp: the logdensity you wish to sample from prior: a function that returns a prior sample From f8d40cc9b9f01ffba9b816001c737225875c4518 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 22 Feb 2024 21:05:34 +0100 Subject: [PATCH 3/3] blub --- sbijax/_src/_sbi_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbijax/_src/_sbi_base.py b/sbijax/_src/_sbi_base.py index 4a1b6a8..91aaf0e 100644 --- a/sbijax/_src/_sbi_base.py +++ b/sbijax/_src/_sbi_base.py @@ -19,7 +19,7 @@ def __init__(self, model_fns): self._len_theta = len(self.prior_sampler_fn(seed=jr.PRNGKey(123))) @abc.abstractmethod - def sample_posterior(self, rng_key: jr.KeyArray, **kwargs): + def sample_posterior(self, rng_key, **kwargs): """Sample from the posterior distribution. Args:
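A minimal usage sketch of the public API after this series, inferred from examples/bivariate_gaussian_snasss.py and the SNASS/SNASSS sources above; the toy shapes and the PRNG key are illustrative assumptions, not part of the patch:

```python
# Hedged sketch: the public entry points keep their old import paths even
# though the implementations moved to sbijax._src.
from jax import numpy as jnp
from jax import random as jr

from sbijax.nn import make_snasss_net  # re-exported from sbijax._src.nn.make_snass_networks

# a 2-d summary statistic, a 1-d second summary, and a scalar critic,
# mirroring the dimensions used in examples/bivariate_gaussian_snasss.py
net = make_snasss_net([64, 64, 2], [64, 64, 1], [64, 64, 1])

y = jnp.ones((8, 10))     # toy batch of 10-dimensional observations (assumption)
theta = jnp.ones((8, 2))  # toy batch of 2-dimensional parameters (assumption)

# the transformed haiku module dispatches on `method`, cf. SNASSSNet.__call__
params = net.init(jr.PRNGKey(0), method="forward", y=y, theta=theta)
summaries = net.apply(params, method="summary", y=y)  # shape (8, 2)
```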