Skip to content

Commit

Permalink
try to store an unstructured chain in the output as well as a structu…
Browse files Browse the repository at this point in the history
…red chain.
  • Loading branch information
bd-j committed Aug 9, 2024
1 parent d76f841 commit ca3948e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ from prospect.fitting import fit_model
output = fit_model(observations, model, sps, **config)
```

Another change is that spectral response functions (i.e. calibration vectors) are now handled by specialized sub-classes of these `Observation` classes. See `spectra`_ for details.

.. _spectra: docs/spectra.rst
Another change is that spectral response functions (i.e. calibration vectors) are now handled by specialized sub-classes of these `Observation` classes. See the [spectroscopy docs](docs/spectra.rst) for details.

Finally, the output chain or samples is now stored as a structured array, where
each row corresponds to a sample, and each column is a parameter (possibly
Expand Down
16 changes: 11 additions & 5 deletions prospect/io/write_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
to HDF5 files as well as to pickles.
"""

import os, time, warnings
from copy import deepcopy
import pickle, json, base64
import warnings
import pickle, json
import numpy as np
from numpy.lib.recfunctions import structured_to_unstructured, unstructured_to_structured

try:
import h5py
_has_h5py_ = True
Expand Down Expand Up @@ -223,6 +224,11 @@ def write_sampling_h5(hf, chain, extras):
sdat = hf.create_group('sampling')

sdat.create_dataset('chain', data=chain)
try:
uchain = structured_to_unstructured(chain)
sdat.create_dataset("unstructured_chain", data=uchain)
except:
pass
for k, v in extras.items():
try:
sdat.create_dataset(k, data=v)
Expand Down Expand Up @@ -279,8 +285,7 @@ def chain_to_struct(chain, model=None, names=None, **extras):
struct :
A structured ndarray of parameter values.
"""
indict = isinstance(chain, dict)
if indict:
if isinstance(chain, dict):
return dict_to_struct(chain)
else:
n = np.prod(chain.shape[:-1])
Expand All @@ -296,6 +301,7 @@ def chain_to_struct(chain, model=None, names=None, **extras):

dt += [(str(k), "<f8") for k in extras.keys()]

# TODO: replace with unstructured_to_structured
struct = np.zeros(n, dtype=np.dtype(dt))
for i, p in enumerate(names):
if model is not None:
Expand Down

0 comments on commit ca3948e

Please sign in to comment.