Skip to content

Commit

Permalink
begin on a unified nested sampling interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
bd-j committed Aug 26, 2024
1 parent ade88bd commit 8609875
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 31 deletions.
2 changes: 1 addition & 1 deletion prospect/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def lnprobfn(theta, model=None, observations=None, sps=None,
ndof = np.sum([obs["ndof"] for obs in observations])
lnnull = np.zeros(ndof) - 1e18 # -np.infty
else:
lnnull = -np.inf
lnnull = -np.infty

# --- Calculate prior probability and exit if not within prior ---
lnp_prior = model.prior_product(theta, nested=nested)
Expand Down
122 changes: 92 additions & 30 deletions prospect/fitting/nested.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,97 @@
import sys, time
import numpy as np
from numpy.random import normal, multivariate_normal

try:
import nestle
except(ImportError):
pass

from .fitting import lnprobfn

try:
import dynesty
from dynesty.utils import *
from dynesty.dynamicsampler import _kld_error
except(ImportError):
pass


__all__ = ["run_nestle_sampler", "run_dynesty_sampler"]
__all__ = ["run_nested", "run_dynesty_sampler"]



def run_nested(observations, model, sps, lnprobfn=lnprobfn, fitter="dynesty",
pool=None, nested_target_n_effective=10000, **kwargs):

go = time.time()

# --- Ultranest ---
if fitter == "ultranest":
# TODO: what about vector parameters

from ultranest import ReactiveNestedSampler
sampler = ReactiveNestedSampler(model.free_params,
lnprobfn,
model.prior_transform)
result = sampler.run(**kwargs)

points = np.array(result['weighted_samples']['points'])
log_w = np.log(np.array(result['weighted_samples']['weights']))
log_like = np.array(result['weighted_samples']['logl'])

# --- Nautilus ---
if fitter == "nautilus":
from nautilus import Prior, Sampler

# we have to use the nautilus prior objects
# TODO: no we don't!
prior = Prior()
for k in params.param_names:
pr = params.priors[k]
if pr.kind == "Normal":
prior.add_parameter(k, dist=norm(pr.params['mean'], pr.params['sigma']))
else:
prior.add_parameter(k, dist=(pr.params['mini'], pr.params['maxi']))
sampler = Sampler(prior, lnprobfn, n_live=1000)
sampler.run(verbose=verbose)

points, log_w, log_like = sampler.posterior()

# --- Dynesty ---
if fitter == "dynesty":
from dynesty import DynamicNestedSampler

sampler_kwargs=dict(nlive=1000,
bound='multi',
sample="unif",
walks=48)

sampler = DynamicNestedSampler(lnprobfn,
model.prior_transform,
model.ndim,
**sampler_kwargs)
sampler.run_nested(n_effective=1000, dlogz_init=0.05)

points = sampler.results["samples"]
log_w = sampler.results["logwt"]
log_like = sampler.results["logl"]

# --- Nestle ---
if fitter == "nestle":
import nestle

sampler_kwargs=dict(method=nestle_method,
npoints=nestle_npoints,
callback=callback,
maxcall=nestle_maxcall,
update_interval=nestle_update_interval)

result = nestle.sample(lnprobfn,
model.prior_transform,
model.ndim,
**sampler_kwargs)

def run_nestle_sampler(lnprobfn, model, verbose=True,
callback=None,
nestle_method='multi', nestle_npoints=200,
nestle_maxcall=int(1e6), nestle_update_interval=None,
**kwargs):
points = result["samples"]
log_w = result["logwt"]
log_like = result["logl"]

result = nestle.sample(lnprobfn, model.prior_transform, model.ndim,
method=nestle_method, npoints=nestle_npoints,
callback=callback, maxcall=nestle_maxcall,
update_interval=nestle_update_interval)
return result
dur = time.time() - go

return (points, log_w, log_like)


def run_dynesty_sampler(lnprobfn, prior_transform, ndim,
Expand Down Expand Up @@ -67,20 +129,24 @@ def run_dynesty_sampler(lnprobfn, prior_transform, ndim,
print_progress=True,
**extras):

from dynesty import DynamicNestedSampler


# instantiate sampler
dsampler = dynesty.DynamicNestedSampler(lnprobfn, prior_transform, ndim,
bound=nested_bound,
sample=nested_sample,
walks=nested_walks,
bootstrap=nested_bootstrap,
update_interval=nested_update_interval,
pool=pool, queue_size=queue_size, use_pool=use_pool
)
dsampler = DynamicNestedSampler(lnprobfn, prior_transform, ndim,
bound=nested_bound,
sample=nested_sample,
walks=nested_walks,
bootstrap=nested_bootstrap,
update_interval=nested_update_interval,
pool=pool, queue_size=queue_size, use_pool=use_pool
)

# generator for initial nested sampling
ncall = dsampler.ncall
niter = dsampler.it - 1
tstart = time.time()

for results in dsampler.sample_initial(nlive=nested_nlive_init,
dlogz=nested_dlogz_init,
maxcall=nested_maxcall_init,
Expand Down Expand Up @@ -191,9 +257,5 @@ def run_dynesty_sampler(lnprobfn, prior_transform, ndim,
# We're done!
break

ndur = time.time() - tstart
if verbose:
print('done dynesty (dynamic) in {0}s'.format(ndur))

return dsampler.results

0 comments on commit 8609875

Please sign in to comment.