Cleanup (#20)
dirmeier authored Feb 22, 2024
1 parent 9f59b0e commit f2f9722
Showing 37 changed files with 634 additions and 545 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/ci.yaml
@@ -73,7 +73,11 @@ jobs:
pip install hatch
- name: Build package
run: |
pip install jaxlib jax
pip install jaxlib==0.4.24 jax==0.4.24
- name: Run tests
run: |
hatch run test:test
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
@@ -12,9 +12,9 @@ jobs:
matrix:
python-version: [3.11]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install pypa/build
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -60,8 +60,8 @@ repos:
args: ["--ignore-missing-imports"]
files: "(sbijax|examples)"

- repo: https://github.com/jorisroovers/gitlint
rev: v0.18.0
- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
hooks:
- id: gitlint
- id: gitlint-ci
- id: pydocstyle
additional_dependencies: ["toml"]
17 changes: 17 additions & 0 deletions .readthedocs.yaml
@@ -0,0 +1,17 @@
version: 2

build:
os: ubuntu-22.04
tools:
python: "3.11"

sphinx:
builder: html
configuration: docs/conf.py
fail_on_warning: false

python:
install:
- method: pip
path: .
- requirements: docs/requirements.txt
18 changes: 6 additions & 12 deletions examples/bivariate_gaussian_snasss.py
@@ -21,7 +21,7 @@
from surjectors.util import unstack

from sbijax import SNASSS
from sbijax.nn.snasss_net import SNASSSNet
from sbijax.nn import make_snasss_net

W = jr.normal(jr.PRNGKey(0), (2, 10))

@@ -74,23 +74,17 @@ def _flow(method, **kwargs):
return td


def make_critic(dim):
@hk.without_apply_rng
@hk.transform
def _net(method, **kwargs):
net = SNASSSNet([64, 64, dim], [64, 64, 1], [64, 64, 1])
return net(method, **kwargs)

return _net


def run():
y_observed = jnp.array([[2.0, -2.0]]) @ W

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = SNASSS(fns, make_model(2), make_critic(2))
estim = SNASSS(
fns,
make_model(2),
make_snasss_net([64, 64, 2], [64, 64, 1], [64, 64, 1]),
)
optimizer = optax.adam(1e-3)

data, params = None, {}
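Note on the change above: the hand-written make_critic helper is replaced by the library factory make_snasss_net. The two ways of building the critic are sketched below for comparison only; this is illustrative and not an additional change in the commit.

from sbijax.nn import make_snasss_net

# Old (removed above): a hand-rolled haiku transform, critic = make_critic(2),
# which wrapped SNASSSNet([64, 64, 2], [64, 64, 1], [64, 64, 1]).
# New: the same architecture constructed through the factory function.
critic = make_snasss_net([64, 64, 2], [64, 64, 1], [64, 64, 1])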
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -96,3 +96,7 @@ invalid-name,missing-module-docstring,R0801,E0633

[tool.bandit]
skips = ["B101"]

[tool.pydocstyle]
convention= 'google'
match = '^sbijax/.*/((?!_test).)*\.py'
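For reference, the Google convention that pydocstyle now enforces on sbijax modules expects docstrings laid out roughly as below; the function itself is hypothetical and only illustrates the section structure the checker accepts.

def simulate(rng_key, n_simulations=1000):
    """Simulate data from the prior predictive distribution.

    Args:
        rng_key: a random key
        n_simulations: number of data points to simulate

    Returns:
        an array of simulated observations
    """
    ...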
12 changes: 6 additions & 6 deletions sbijax/__init__.py
@@ -5,9 +5,9 @@
__version__ = "0.1.5"


from sbijax.abc.rejection_abc import RejectionABC
from sbijax.abc.smc_abc import SMCABC
from sbijax.snass import SNASS
from sbijax.snasss import SNASSS
from sbijax.snl import SNL
from sbijax.snp import SNP
from sbijax._src.abc.rejection_abc import RejectionABC
from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.snass import SNASS
from sbijax._src.snasss import SNASSS
from sbijax._src.snl import SNL
from sbijax._src.snp import SNP
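Because the public names are re-exported from the new private _src package, user-facing imports should be unchanged by this move; a minimal sanity check, assuming sbijax 0.1.5 is installed:

# The public API is unaffected by moving the implementation into sbijax._src.
from sbijax import SNL, SNP, SNASS, SNASSS, SMCABC, RejectionABC

print(SNL.__module__)  # the class is now defined in sbijax._src.snl internally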
File renamed without changes.
20 changes: 10 additions & 10 deletions sbijax/_sbi_base.py → sbijax/_src/_sbi_base.py
@@ -6,23 +6,23 @@
# pylint: disable=too-many-instance-attributes,unused-argument,
# pylint: disable=too-few-public-methods
class SBI(abc.ABC):
"""
SBI base class
"""
"""SBI base class."""

def __init__(self, model_fns):
"""Construct an SBI object.
Args:
model_fns: tuple
"""
self.prior_sampler_fn, self.prior_log_density_fn = model_fns[0]
self.simulator_fn = model_fns[1]
self._len_theta = len(self.prior_sampler_fn(seed=jr.PRNGKey(123)))

@abc.abstractmethod
def sample_posterior(self, rng_key, **kwargs):
"""
Sample from the posterior distribution
"""Sample from the posterior distribution.
Parameters
----------
rng_key: jax.PRNGKey
a random key
kwargs: keyword arguments with sampler specific parameters
Args:
rng_key: a random key
kwargs: keyword arguments with sampler specific parameters
"""
117 changes: 51 additions & 66 deletions sbijax/_sne_base.py → sbijax/_src/_sne_base.py
@@ -4,19 +4,22 @@
from jax import numpy as jnp
from jax import random as jr

from sbijax import generator
from sbijax._sbi_base import SBI
from sbijax.generator import named_dataset
from sbijax._src._sbi_base import SBI
from sbijax._src.generator import as_batch_iterators, named_dataset


# pylint: disable=too-many-arguments,unused-argument
# pylint: disable=too-many-function-args,arguments-differ
class SNE(SBI, ABC):
"""
Sequential neural estimation
"""
"""Sequential neural estimation base class."""

def __init__(self, model_fns, density_estimator):
"""Construct an SNE object.
Args:
model_fns: tuple
density_estimator: maf
"""
super().__init__(model_fns)
self.model = density_estimator
self.n_total_simulations = 0
@@ -30,31 +33,19 @@ def simulate_data_and_possibly_append(
n_simulations=1000,
**kwargs,
):
"""
Simulate data from the posteriorand append it to an existing data set
(if provided)
Parameters
----------
rng_key: jax.PRNGKey
a random key
params: pytree
a dictionary of neural network parameters
observable: jnp.ndarray
an observation
data: NamedTuple
existing data set
n_simulations: int
number of newly simulated data
kwargs: keyword arguments
dictionary of ey value pairs passed to `sample_posterior`
Returns
-------
NamedTuple:
"""Simulate data from the prior or posterior and append.
Args:
rng_key: a random key
params: a dictionary of neural network parameters
observable: an observation
data: existing data set
n_simulations: number of newly simulated data
kwargs: dictionary of key-value pairs passed to `sample_posterior`
Returns:
returns a NamedTuple with two fields, y and theta
"""

observable = jnp.atleast_2d(observable)
new_data, diagnostics = self.simulate_data(
rng_key,
Expand All @@ -78,31 +69,21 @@ def simulate_data(
n_simulations=1000,
**kwargs,
):
r"""Simulate data from the posterior or prior and append.
Args:
rng_key: a random key
params: a dictionary of neural network parameters. If None, will
draw from the prior. If parameters are given, will draw from the
amortized posterior using `observable`
observable: an observation. Needs to be given if posterior draws
are desired
n_simulations: number of newly simulated data
kwargs: dictionary of key-value pairs passed to `sample_posterior`
Returns:
a NamedTuple with two fields, y and theta
"""
Simulate data from the posterior or prior and append it to an
existing data set (if provided)
Parameters
----------
rng_key: jax.PRNGKey
a random key
params: Optional[pytree]
a dictionary of neural network parameters. If None, will draw from
prior. If parameters given, will draw from amortized posterior
using 'observable;
observable: Optional[jnp.ndarray]
an observation. Needs to be gfiven if posterior draws are desired
n_simulations: int
number of newly simulated data
kwargs: keyword arguments
dictionary of ey value pairs passed to `sample_posterior`
Returns
-------
NamedTuple:
returns a NamedTuple of two axis, y and theta
"""

sample_key, rng_key = jr.split(rng_key)
if params is None or len(params) == 0:
diagnostics = None
@@ -140,21 +121,15 @@ def simulate_data(

@staticmethod
def stack_data(data, also_data):
"""
Stack two data sets.
"""Stack two data sets.
Parameters
----------
data: NamedTuple
one data set
also_data: : NamedTuple
Args:
data: one data set
also_data: another data set
Returns
-------
NamedTuple:
Returns:
returns the stack of the two data sets
"""

if data is None:
return also_data
if also_data is None:
@@ -166,8 +141,18 @@ def stack_data(data, also_data):
def as_iterators(
self, rng_key, data, batch_size, percentage_data_as_validation_set
):
"""Convert the data set to an iterable for training"""
return generator.as_batch_iterators(
"""Convert the data set to an iterable for training.
Args:
rng_key: random key
data: tuple
batch_size: integer
percentage_data_as_validation_set: fraction
Returns:
a batch iterator
"""
return as_batch_iterators(
rng_key,
data,
batch_size,
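As a rough intuition for what the `stack_data` docstring above describes, a minimal, self-contained sketch (not the library's implementation) that stacks two datasets field-wise along the first axis could look like this:

import jax
from jax import numpy as jnp


def stack_data_sketch(data, also_data):
    # Concatenate each field (y and theta) of the two datasets row-wise.
    if data is None:
        return also_data
    if also_data is None:
        return data
    return jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate([a, b], axis=0), data, also_data
    )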
Empty file added sbijax/_src/abc/__init__.py
Empty file.
56 changes: 30 additions & 26 deletions sbijax/abc/rejection_abc.py → sbijax/_src/abc/rejection_abc.py
@@ -1,19 +1,33 @@
from typing import Callable, Tuple

from jax import numpy as jnp
from jax import random as jr

from sbijax._sbi_base import SBI
from sbijax._src._sbi_base import SBI


# pylint: disable=too-many-instance-attributes,too-many-arguments
# pylint: disable=too-many-locals,too-few-public-methods
# pylint: disable=too-many-locals,too-few-public-methods,
class RejectionABC(SBI):
"""
Sisson et al. - Handbook of approximate Bayesian computation
"""Rejection approximate Bayesian computation.
Algorithm 4.1, "ABC Rejection Sampling Algorithm"
Implements algorithm~4.1 from [1].
References:
.. [1] Sisson, Scott A, et al. "Handbook of approximate Bayesian
computation". 2019
"""

def __init__(self, model_fns, summary_fn, kernel_fn):
def __init__(
self, model_fns: Tuple, summary_fn: Callable, kernel_fn: Callable
):
"""Constructs a RejectionABC object.
Args:
model_fns: tuple
summary_fn: summary statistics function
kernel_fn: a kernel function to compute similarities
"""
super().__init__(model_fns)
self.kernel_fn = kernel_fn
self.summary_fn = summary_fn
@@ -29,31 +43,21 @@ def sample_posterior(
h,
**kwargs,
):
"""
Sample from the approximate posterior
r"""Sample from the approximate posterior.
Parameters
----------
rng_key: jax.PRNGKey
a random key
observable: jnp.Array
observation to condition on
n_samples: int
number of samples to draw for each parameter
n_simulations_per_theta: int
number of simulations for each paramter sample
K: double
normalisation parameter
h: double
kernel scale
Args:
rng_key: a random key
observable: observation to condition on
n_samples: number of samples to draw for each parameter
n_simulations_per_theta: number of simulations for each parameter
sample
K: normalisation parameter
h: kernel scale
Returns
-------
chex.Array
Returns:
an array of samples from the posterior distribution of dimension
(n_samples \times p)
"""

observable = jnp.atleast_2d(observable)

thetas = None
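To make the Algorithm 4.1 referenced in the RejectionABC docstring above concrete, the following is a hedged, self-contained sketch of kernel-based ABC rejection sampling in the spirit of Sisson et al.; the names, the simulator signature, and the kernel-based acceptance rule are illustrative assumptions and do not mirror sbijax's internal implementation.

from jax import numpy as jnp
from jax import random as jr


def rejection_abc_sketch(rng_key, prior_sample_fn, simulator_fn, summary_fn,
                         kernel_fn, observable, n_samples, K, h):
    """Illustrative kernel-based ABC rejection sampler (not sbijax code)."""
    s_obs = summary_fn(observable)
    accepted = []
    while len(accepted) < n_samples:
        prior_key, sim_key, accept_key, rng_key = jr.split(rng_key, 4)
        # Draw a candidate parameter from the prior and simulate data for it.
        theta = prior_sample_fn(seed=prior_key)
        y = simulator_fn(sim_key, theta)
        # Accept theta with probability given by the kernel-weighted distance
        # between simulated and observed summary statistics, normalised by K.
        prob = kernel_fn((summary_fn(y) - s_obs) / h) / K
        if jr.uniform(accept_key) < prob:
            accepted.append(theta)
    return jnp.stack(accepted)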