Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

blubCleanup #20

Merged
merged 3 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ jobs:
pip install hatch
- name: Build package
run: |
pip install jaxlib jax
pip install jaxlib==0.4.24 jax==0.4.24
- name: Run tests
run: |
hatch run test:test
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ jobs:
matrix:
python-version: [3.11]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install pypa/build
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ repos:
args: ["--ignore-missing-imports"]
files: "(sbijax|examples)"

- repo: https://github.com/jorisroovers/gitlint
rev: v0.18.0
- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
hooks:
- id: gitlint
- id: gitlint-ci
- id: pydocstyle
additional_dependencies: ["toml"]
17 changes: 17 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
version: 2

build:
os: ubuntu-22.04
tools:
python: "3.11"

sphinx:
builder: html
configuration: docs/conf.py
fail_on_warning: false

python:
install:
- method: pip
path: .
- requirements: docs/requirements.txt
18 changes: 6 additions & 12 deletions examples/bivariate_gaussian_snasss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from surjectors.util import unstack

from sbijax import SNASSS
from sbijax.nn.snasss_net import SNASSSNet
from sbijax.nn import make_snasss_net

W = jr.normal(jr.PRNGKey(0), (2, 10))

Expand Down Expand Up @@ -74,23 +74,17 @@ def _flow(method, **kwargs):
return td


def make_critic(dim):
@hk.without_apply_rng
@hk.transform
def _net(method, **kwargs):
net = SNASSSNet([64, 64, dim], [64, 64, 1], [64, 64, 1])
return net(method, **kwargs)

return _net


def run():
y_observed = jnp.array([[2.0, -2.0]]) @ W

prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

estim = SNASSS(fns, make_model(2), make_critic(2))
estim = SNASSS(
fns,
make_model(2),
make_snasss_net([64, 64, 2], [64, 64, 1], [64, 64, 1]),
)
optimizer = optax.adam(1e-3)

data, params = None, {}
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,7 @@ invalid-name,missing-module-docstring,R0801,E0633

[tool.bandit]
skips = ["B101"]

[tool.pydocstyle]
convention= 'google'
match = '^sbijax/.*/((?!_test).)*\.py'
12 changes: 6 additions & 6 deletions sbijax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
__version__ = "0.1.5"


from sbijax.abc.rejection_abc import RejectionABC
from sbijax.abc.smc_abc import SMCABC
from sbijax.snass import SNASS
from sbijax.snasss import SNASSS
from sbijax.snl import SNL
from sbijax.snp import SNP
from sbijax._src.abc.rejection_abc import RejectionABC
from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.snass import SNASS
from sbijax._src.snasss import SNASSS
from sbijax._src.snl import SNL
from sbijax._src.snp import SNP
File renamed without changes.
20 changes: 10 additions & 10 deletions sbijax/_sbi_base.py → sbijax/_src/_sbi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@
# pylint: disable=too-many-instance-attributes,unused-argument,
# pylint: disable=too-few-public-methods
class SBI(abc.ABC):
"""
SBI base class
"""
"""SBI base class."""

def __init__(self, model_fns):
"""Construct an SBI object.

Args:
model_fns: tuple
"""
self.prior_sampler_fn, self.prior_log_density_fn = model_fns[0]
self.simulator_fn = model_fns[1]
self._len_theta = len(self.prior_sampler_fn(seed=jr.PRNGKey(123)))

@abc.abstractmethod
def sample_posterior(self, rng_key, **kwargs):
"""
Sample from the posterior distribution
"""Sample from the posterior distribution.

Parameters
----------
rng_key: jax.PRNGKey
a random key
kwargs: keyword arguments with sampler specific parameters
Args:
rng_key: a random key
kwargs: keyword arguments with sampler specific parameters
"""
117 changes: 51 additions & 66 deletions sbijax/_sne_base.py → sbijax/_src/_sne_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@
from jax import numpy as jnp
from jax import random as jr

from sbijax import generator
from sbijax._sbi_base import SBI
from sbijax.generator import named_dataset
from sbijax._src._sbi_base import SBI
from sbijax._src.generator import as_batch_iterators, named_dataset


# pylint: disable=too-many-arguments,unused-argument
# pylint: disable=too-many-function-args,arguments-differ
class SNE(SBI, ABC):
"""
Sequential neural estimation
"""
"""Sequential neural estimation base class."""

def __init__(self, model_fns, density_estimator):
"""Construct an SNE object.

Args:
model_fns: tuple
density_estimator: maf
"""
super().__init__(model_fns)
self.model = density_estimator
self.n_total_simulations = 0
Expand All @@ -30,31 +33,19 @@ def simulate_data_and_possibly_append(
n_simulations=1000,
**kwargs,
):
"""
Simulate data from the posteriorand append it to an existing data set
(if provided)

Parameters
----------
rng_key: jax.PRNGKey
a random key
params: pytree
a dictionary of neural network parameters
observable: jnp.ndarray
an observation
data: NamedTuple
existing data set
n_simulations: int
number of newly simulated data
kwargs: keyword arguments
dictionary of ey value pairs passed to `sample_posterior`

Returns
-------
NamedTuple:
"""Simulate data from the prior or posterior and append.

Args:
rng_key: a random key
params: a dictionary of neural network parameters
observable: an observation
data: existing data set
n_simulations: number of newly simulated data
kwargs: dictionary of ey value pairs passed to `sample_posterior`

Returns:
returns a NamedTuple of two axis, y and theta
"""

observable = jnp.atleast_2d(observable)
new_data, diagnostics = self.simulate_data(
rng_key,
Expand All @@ -78,31 +69,21 @@ def simulate_data(
n_simulations=1000,
**kwargs,
):
r"""Simulate data from the posterior or prior and append.

Args:
rng_key: a random key
params:a dictionary of neural network parameters. If None, will
draw from prior. If parameters given, will draw from amortized
posterior using 'observable;
observable: an observation. Needs to be gfiven if posterior draws
are desired
n_simulations: number of newly simulated data
kwargs: dictionary of ey value pairs passed to `sample_posterior`

Returns:
a NamedTuple of two axis, y and theta
"""
Simulate data from the posterior or prior and append it to an
existing data set (if provided)

Parameters
----------
rng_key: jax.PRNGKey
a random key
params: Optional[pytree]
a dictionary of neural network parameters. If None, will draw from
prior. If parameters given, will draw from amortized posterior
using 'observable;
observable: Optional[jnp.ndarray]
an observation. Needs to be gfiven if posterior draws are desired
n_simulations: int
number of newly simulated data
kwargs: keyword arguments
dictionary of ey value pairs passed to `sample_posterior`

Returns
-------
NamedTuple:
returns a NamedTuple of two axis, y and theta
"""

sample_key, rng_key = jr.split(rng_key)
if params is None or len(params) == 0:
diagnostics = None
Expand Down Expand Up @@ -140,21 +121,15 @@ def simulate_data(

@staticmethod
def stack_data(data, also_data):
"""
Stack two data sets.
"""Stack two data sets.

Parameters
----------
data: NamedTuple
one data set
also_data: : NamedTuple
Args:
data: one data set
also_data: another data set

Returns
-------
NamedTuple:
Returns:
returns the stack of the two data sets
"""

if data is None:
return also_data
if also_data is None:
Expand All @@ -166,8 +141,18 @@ def stack_data(data, also_data):
def as_iterators(
self, rng_key, data, batch_size, percentage_data_as_validation_set
):
"""Convert the data set to an iterable for training"""
return generator.as_batch_iterators(
"""Convert the data set to an iterable for training.

Args:
rng_key: random key
data: tuple
batch_size: integer
percentage_data_as_validation_set: fraction

Returns:
a batch iterator
"""
return as_batch_iterators(
rng_key,
data,
batch_size,
Expand Down
Empty file added sbijax/_src/abc/__init__.py
Empty file.
56 changes: 30 additions & 26 deletions sbijax/abc/rejection_abc.py → sbijax/_src/abc/rejection_abc.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
from typing import Callable, Tuple

from jax import numpy as jnp
from jax import random as jr

from sbijax._sbi_base import SBI
from sbijax._src._sbi_base import SBI


# pylint: disable=too-many-instance-attributes,too-many-arguments
# pylint: disable=too-many-locals,too-few-public-methods
# pylint: disable=too-many-locals,too-few-public-methods,
class RejectionABC(SBI):
"""
Sisson et al. - Handbook of approximate Bayesian computation
"""Rejection approximate Bayesian computation.

Algorithm 4.1, "ABC Rejection Sampling Algorithm"
Implements algorithm~4.1 from [1].

References:
.. [1] Sisson, Scott A, et al. "Handbook of approximate Bayesian
computation". 2019
"""

def __init__(self, model_fns, summary_fn, kernel_fn):
def __init__(
self, model_fns: Tuple, summary_fn: Callable, kernel_fn: Callable
):
"""Constructs a RejectionABC object.

Args:
model_fns: tuple
summary_fn: summary statistice function
kernel_fn: a kernel function to compute similarities
"""
super().__init__(model_fns)
self.kernel_fn = kernel_fn
self.summary_fn = summary_fn
Expand All @@ -29,31 +43,21 @@ def sample_posterior(
h,
**kwargs,
):
"""
Sample from the approximate posterior
r"""Sample from the approximate posterior.

Parameters
----------
rng_key: jax.PRNGKey
a random key
observable: jnp.Array
observation to condition on
n_samples: int
number of samples to draw for each parameter
n_simulations_per_theta: int
number of simulations for each paramter sample
K: double
normalisation parameter
h: double
kernel scale
Args:
rng_key: a random key
observable: observation to condition on
n_samples: number of samples to draw for each parameter
n_simulations_per_theta: number of simulations for each parameter
sample
K: normalisation parameter
h: kernel scale

Returns
-------
chex.Array
Returns:
an array of samples from the posterior distribution of dimension
(n_samples \times p)
"""

observable = jnp.atleast_2d(observable)

thetas = None
Expand Down
Loading
Loading