Skip to content

Commit

Permalink
more code to implement different smaplers, including in output.
Browse files Browse the repository at this point in the history
  • Loading branch information
bd-j committed Aug 26, 2024
1 parent 355c256 commit d34f580
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 322 deletions.
122 changes: 59 additions & 63 deletions prospect/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from .minimizer import minimize_wrapper, minimizer_ball
from .ensemble import run_emcee_sampler
from .nested import run_dynesty_sampler
from .nested import run_nested_sampler, parse_nested_kwargs
from ..likelihood.likelihood import compute_chi, compute_lnlike


__all__ = ["lnprobfn", "fit_model",
"run_minimize", "run_emcee", "run_dynesty"
"run_minimize", "run_ensemble", "run_nested"
]


Expand Down Expand Up @@ -123,7 +123,8 @@ def wrap_lnp(lnpfn, observations, model, sps, **lnp_kwargs):


def fit_model(observations, model, sps, lnprobfn=lnprobfn,
optimize=False, emcee=False, dynesty=True, **kwargs):
optimize=False, emcee=False, nested_sampler="",
**kwargs):
"""Fit a model to observations using a number of different methods
Parameters
Expand Down Expand Up @@ -167,7 +168,7 @@ def fit_model(observations, model, sps, lnprobfn=lnprobfn,
+ ``hfile``: an open h5py.File file handle for writing result incrementally
Many additional emcee parameters can be provided here, see
:py:func:`run_emcee` for details.
:py:func:`run_ensemble` for details.
dynesty : bool (optional, default: True)
If ``True``, sample from the posterior using dynesty. Additonal
Expand All @@ -184,11 +185,7 @@ def fit_model(observations, model, sps, lnprobfn=lnprobfn,
# Make sure obs has required keys
[obs.rectify() for obs in observations]

if emcee & dynesty:
msg = ("Cannot run both emcee and dynesty fits "
"in a single call to fit_model")
raise(ValueError, msg)
if (not emcee) & (not dynesty) & (not optimize):
if (not bool(emcee)) & (not bool(nested_sampler)) & (not optimize):
msg = ("No sampling or optimization routine "
"specified by user; returning empty results")
warnings.warn(msg)
Expand All @@ -204,14 +201,16 @@ def fit_model(observations, model, sps, lnprobfn=lnprobfn,
output["optimization"] = (optres, topt)

if emcee:
run_sampler = run_emcee
elif dynesty:
run_sampler = run_dynesty
run_sampler = run_ensemble
elif nested_sampler:
run_sampler = run_nested
kwargs["fitter"] = nested_sampler
else:
return output

output["sampling"] = run_sampler(observations, model, sps,
lnprobfn=lnprobfn, **kwargs)
lnprobfn=lnprobfn,
**kwargs)
return output


Expand Down Expand Up @@ -305,8 +304,8 @@ def run_minimize(observations=None, model=None, sps=None, lnprobfn=lnprobfn,
return results, tm, best


def run_emcee(observations, model, sps, lnprobfn=lnprobfn,
hfile=None, initial_positions=None, **kwargs):
def run_ensemble(observations, model, sps, lnprobfn=lnprobfn,
hfile=None, initial_positions=None, **kwargs):
"""Run emcee, optionally including burn-in and convergence checking. Thin
wrapper on :py:class:`prospect.fitting.ensemble.run_emcee_sampler`
Expand Down Expand Up @@ -387,6 +386,7 @@ def run_emcee(observations, model, sps, lnprobfn=lnprobfn,
q = model.theta.copy()

postkwargs = {}
# Hack for MPI pools to access the global namespace
for item in ['observations', 'model', 'sps']:
val = eval(item)
if val is not None:
Expand All @@ -396,26 +396,32 @@ def run_emcee(observations, model, sps, lnprobfn=lnprobfn,
# Could try to make signatures for these two methods the same....
if initial_positions is not None:
raise NotImplementedError
meth = restart_emcee_sampler
t = time.time()
out = meth(lnprobfn, initial_positions, hdf5=hfile,
postkwargs=postkwargs, **kwargs)
go = time.time()
out = restart_emcee_sampler(lnprobfn, initial_positions,
hdf5=hfile,
postkwargs=postkwargs,
**kwargs)
sampler = out
ts = time.time() - t
ts = time.time() - go
else:
meth = run_emcee_sampler
t = time.time()
out = meth(lnprobfn, q, model, hdf5=hfile,
postkwargs=postkwargs, **kwargs)
sampler, burn_p0, burn_prob0 = out
ts = time.time() - t
go = time.time()
out = run_emcee_sampler(lnprobfn, q, model,
hdf5=hfile,
postkwargs=postkwargs,
**kwargs)
sampler, burn_loc0, burn_prob0 = out
ts = time.time() - go

return sampler, ts


def run_dynesty(observations, model, sps, lnprobfn=lnprobfn,
pool=None, nested_target_n_effective=10000, **kwargs):
"""Thin wrapper on :py:class:`prospect.fitting.nested.run_dynesty_sampler`
def run_nested(observations, model, sps,
lnprobfn=lnprobfn,
fitter="dynesty",
nested_nlive=1000,
nested_neff=1000,
**kwargs):
"""Thin wrapper on :py:class:`prospect.fitting.nested.run_nested_sampler`
Parameters
----------
Expand All @@ -436,43 +442,33 @@ def run_dynesty(observations, model, sps, lnprobfn=lnprobfn,
``model``, and ``sps`` as keywords. By default use the
:py:func:`lnprobfn` defined above.
Extra Parameters
--------
nested_bound: (optional, default: 'multi')
nested_sample: (optional, default: 'unif')
nested_nlive_init: (optional, default: 100)
nested_nlive_batch: (optional, default: 100)
nested_dlogz_init: (optional, default: 0.02)
nested_maxcall: (optional, default: None)
nested_walks: (optional, default: 25)
Returns
--------
result:
An instance of :py:class:`dynesty.results.Results`.
result: Dictionary
Will have keys:
* points : parameter location of the samples
* log_weight : ln of the weights of each sample
* log_like : ln of the likelihoods of each sample
t_wall : float
Duration of sampling in seconds of wall time.
"""
from dynesty.dynamicsampler import stopping_function, weight_function
nested_stop_kwargs = {"target_n_effective": nested_target_n_effective}

lnp = wrap_lnp(lnprobfn, observations, model, sps, nested=True)

# Need to deal with postkwargs...

t = time.time()
dynestyout = run_dynesty_sampler(lnp, model.prior_transform, model.ndim,
stop_function=stopping_function,
wt_function=weight_function,
nested_stop_kwargs=nested_stop_kwargs,
pool=pool, **kwargs)
ts = time.time() - t

return dynestyout, ts
# wrap the probability fiunction, making sure it's a likelihood
likelihood = wrap_lnp(lnprobfn, observations, model, sps, nested=True)

# which keywords do we have for this fitter?
ns_kwargs, nr_kwargs = parse_nested_kwargs(fitter=fitter,
**kwargs)

go = time.time()
output = run_nested_sampler(model,
likelihood,
fitter=fitter,
verbose=False,
nested_nlive=nested_nlive,
nested_neff=nested_neff,
nested_sampler_kwargs=ns_kwargs,
nested_run_kwargs=nr_kwargs)
ts = time.time() - go

return output, ts
Loading

0 comments on commit d34f580

Please sign in to comment.