Skip to content

Commit

Permalink
Model fitting table (#2093)
Browse files Browse the repository at this point in the history
* table subcomponent: support for adding additional columns
* results table for model fitting
* units, fixed (and eventually uncert) will be hidden in the UI by default
* store uncertainties in table and expose in model_fitting.get_model_component
* test coverage
* add docs
* skip table logging for cube fits with note added to UI when enabling cube fit
  • Loading branch information
kecnry authored Mar 24, 2023
1 parent 5510486 commit b6dd6c1
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
New Features
------------

* Model fitting results are logged in a table within the plugin [#2093].

Cubeviz
^^^^^^^

Expand Down
11 changes: 11 additions & 0 deletions docs/specviz/export_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ To extract all of the model parameters:
myparams
where the ``model_label`` parameter identifies which model should be returned.

Alternatively, the table of logged parameter values in the model fitting plugin can be exported to
an :ref:`astropy table <astropy:astropy-table>`
by calling :meth:`~jdaviz.core.template_mixin.TableMixin.export_table` (see :ref:`plugin-apis`):

.. code-block:: python
model_fitting = specviz.plugins['Model Fitting']
model_fitting.export_table()
.. _specviz-export-markers:

Markers Table
Expand Down
5 changes: 5 additions & 0 deletions docs/specviz/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ show the fitted value of each parameter rather than the initial value, and
will additionally show the standard deviation uncertainty of the fitted
parameter value if the parameter was not set to be fixed to the initial value.

Parameter values for each fitting run are stored in the plugin table.
To export the table into the notebook via the API, call
:meth:`~jdaviz.core.template_mixin.TableMixin.export_table`
(see :ref:`plugin-apis`).

.. seealso::

:ref:`Export Models <specviz-export-model>`
Expand Down
1 change: 1 addition & 0 deletions jdaviz/configs/default/plugins/markers/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self, *args, **kwargs):

self.table.headers_avail = headers
self.table.headers_visible = headers
self.table._default_values_by_colname = _default_table_values

# subscribe to mouse events on any new viewers
self.hub.subscribe(self, ViewerAddedMessage, handler=self._on_viewer_added)
Expand Down
63 changes: 57 additions & 6 deletions jdaviz/configs/default/plugins/model_fitting/model_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
DatasetSelectMixin,
DatasetSpectralSubsetValidMixin,
AutoTextField,
AddResultsMixin)
AddResultsMixin,
TableMixin)
from jdaviz.core.custom_traitlets import IntHandleEmpty
from jdaviz.core.user_api import PluginUserApi
from jdaviz.configs.default.plugins.model_fitting.fitting_backend import fit_model_to_spectrum
Expand All @@ -39,7 +40,7 @@ def __init__(self, value, unit=None):
@tray_registry('g-model-fitting', label="Model Fitting", viewer_requirements='spectrum')
class ModelFitting(PluginTemplateMixin, DatasetSelectMixin,
SpectralSubsetSelectMixin, DatasetSpectralSubsetValidMixin,
AddResultsMixin):
AddResultsMixin, TableMixin):
"""
See the :ref:`Model Fitting Plugin Documentation <specviz-model-fitting>` for more details.
Expand All @@ -63,6 +64,7 @@ class ModelFitting(PluginTemplateMixin, DatasetSelectMixin,
* :meth:`set_model_component`
* :meth:`reestimate_model_parameters`
* ``equation`` (:class:`~jdaviz.core.template_mixin.AutoTextField`)
* :meth:`equation_components`
* ``cube_fit``
Only exposed for Cubeviz. Whether to fit the model to the cube instead of to the
collapsed spectrum.
Expand Down Expand Up @@ -154,6 +156,16 @@ def __init__(self, *args, **kwargs):
self.residuals = AutoTextField(self, 'residuals_label', 'residuals_label_default',
'residuals_label_auto', 'residuals_label_invalid_msg')

headers = ['model', 'data_label', 'spectral_subset', 'equation']
if self.config == 'cubeviz':
headers += ['spatial_subset', 'cube_fit']

self.table.headers_avail = headers
self.table.headers_visible = headers
# when model parameters are added as columns, only show the value columns by default
# (other columns can be show in the dropdown by the user)
self.table._new_col_visible = lambda colname: colname.split(':')[-1] not in ('unit', 'fixed', 'uncert', 'std') # noqa

# set the filter on the viewer options
self._update_viewer_filters()

Expand All @@ -165,10 +177,11 @@ def user_api(self):
expose += ['spectral_subset', 'model_component', 'poly_order', 'model_component_label',
'model_components', 'create_model_component', 'remove_model_component',
'get_model_component', 'set_model_component', 'reestimate_model_parameters',
'equation', 'add_results', 'residuals_calculate', 'residuals']
'equation', 'equation_components',
'add_results', 'residuals_calculate', 'residuals']
if self.config == "cubeviz":
expose += ['cube_fit']
expose += ['calculate_fit']
expose += ['calculate_fit', 'clear_table', 'export_table']
return PluginUserApi(self, expose=expose)

def _param_units(self, param, model_type=None):
Expand Down Expand Up @@ -531,6 +544,7 @@ def get_model_component(self, model_component_label, parameter=None):
comp = {"model_type": model_component['model_type'],
"parameters": {p['name']: {'value': p['value'],
'unit': p['unit'],
'std': p.get('std', np.nan),
'fixed': p['fixed']} for p in model_component['parameters']}} # noqa

if parameter is not None:
Expand Down Expand Up @@ -630,6 +644,13 @@ def model_components(self):
"""
return [x["id"] for x in self.component_models]

@property
def equation_components(self):
"""
List of the labels of model components in the current equation
"""
return re.split('[+*/-]', self.equation.value)

def vue_add_model(self, event):
self.create_model_component()

Expand Down Expand Up @@ -688,9 +709,39 @@ def calculate_fit(self, add_data=True):
raise ValueError(f"spectral subset '{self.spectral_subset.selected}' {subset_range} is outside data range of '{self.dataset.selected}' {spec_range}") # noqa

if self.cube_fit:
return self._fit_model_to_cube(add_data=add_data)
ret = self._fit_model_to_cube(add_data=add_data)
else:
return self._fit_model_to_spectrum(add_data=add_data)
ret = self._fit_model_to_spectrum(add_data=add_data)

if ret is None: # pragma: no cover
# something went wrong in the fitting call and (hopefully) already raised a warning,
# but we don't have anything to add to the table
return ret

if self.cube_fit:
# cube fits are currently unsupported in tables
return ret

row = {'model': self.results_label if add_data else '',
'data_label': self.dataset_selected,
'spectral_subset': self.spectral_subset_selected,
'equation': self.equation.value}
if self.app.config == 'cubeviz':
row['spatial_subset'] = self.spatial_subset_selected
row['cube_fit'] = self.cube_fit

equation_components = self.equation_components
for comp_ind, comp in enumerate(equation_components):
for param_name, param_dict in self.get_model_component(comp).get('parameters', {}).items(): # noqa
colprefix = f"{comp}:{param_name}_{comp_ind}"
row[colprefix] = param_dict.get('value')
row[f"{colprefix}:unit"] = param_dict.get('unit')
row[f"{colprefix}:fixed"] = param_dict.get('fixed')
row[f"{colprefix}:std"] = param_dict.get('std')

self.table.add_item(row)

return ret

def vue_apply(self, event):
self.calculate_fit()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@
></v-switch>
</v-row>

<v-row v-if="cube_fit">
<span class="v-messages v-messages__message text--secondary">
Note: cube fit results are not logged to table.
</span>
</v-row>

<plugin-add-results
:label.sync="results_label"
:label_default="results_label_default"
Expand Down Expand Up @@ -267,6 +273,9 @@
If fit is not sufficiently converged, click Fit Model again to run additional iterations.
</span>
</v-row>

<j-plugin-section-header>Results History</j-plugin-section-header>
<jupyter-widget :widget="table_widget"></jupyter-widget>
</div>
</j-tray-plugin>
</template>
Expand Down
39 changes: 39 additions & 0 deletions jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,42 @@ def test_cube_fitting_backend(cubeviz_helper, unc, tmp_path):
data_mask = cubeviz_helper.app.data_collection["fitted_cube.fits[MASK]"]
flux_mask = data_mask.get_component("flux")
assert_array_equal(flux_mask.data, mask)


@pytest.mark.filterwarnings(r"ignore:Model is linear in parameters.*")
@pytest.mark.filterwarnings(r"ignore:The fit may be unsuccessful.*")
def test_results_table(specviz_helper, spectrum1d):
data_label = 'test'
specviz_helper.load_data(spectrum1d, data_label=data_label)

mf = specviz_helper.plugins['Model Fitting']
mf.create_model_component('Linear1D')

mf.add_results.label = 'linear model'
mf.calculate_fit(add_data=True)
mf_table = mf.export_table()
assert len(mf_table) == 1
assert mf_table['equation'][-1] == 'L'
assert mf_table.colnames == ['model', 'data_label', 'spectral_subset', 'equation',
'L:slope_0', 'L:slope_0:unit',
'L:slope_0:fixed', 'L:slope_0:std',
'L:intercept_0', 'L:intercept_0:unit',
'L:intercept_0:fixed', 'L:intercept_0:std']

mf.create_model_component('Gaussian1D')
mf.add_results.label = 'composite model'
mf.calculate_fit(add_data=True)
mf_table = mf.export_table()
assert len(mf_table) == 2
assert mf_table['equation'][-1] == 'L+G'
assert mf_table.colnames == ['model', 'data_label', 'spectral_subset', 'equation',
'L:slope_0', 'L:slope_0:unit',
'L:slope_0:fixed', 'L:slope_0:std',
'L:intercept_0', 'L:intercept_0:unit',
'L:intercept_0:fixed', 'L:intercept_0:std',
'G:amplitude_1', 'G:amplitude_1:unit',
'G:amplitude_1:fixed', 'G:amplitude_1:std',
'G:mean_1', 'G:mean_1:unit',
'G:mean_1:fixed', 'G:mean_1:std',
'G:stddev_1', 'G:stddev_1:unit',
'G:stddev_1:fixed', 'G:stddev_1:std']
32 changes: 30 additions & 2 deletions jdaviz/core/template_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2293,6 +2293,8 @@ class Table(PluginSubcomponent):
"""
template_file = __file__, "../components/plugin_table.vue"

_default_values_by_colname = {}

headers_visible = List([]).tag(sync=True) # list of strings
headers_avail = List([]).tag(sync=True) # list of strings
items = List().tag(sync=True) # list of dictionaries, pass single dict to add_row
Expand All @@ -2301,6 +2303,21 @@ def __init__(self, plugin, *args, **kwargs):
self._qtable = None
super().__init__(plugin, 'Table', *args, **kwargs)

def default_value_for_column(self, colname=None, value=None):
if colname in self._default_values_by_colname:
return self._default_values_by_colname.get(colname)
if isinstance(value, (tuple, list)):
return [self.default_value_for_column(value=v) for v in value]
if isinstance(value, (float, int)):
return np.nan
if isinstance(value, str):
return ''
return None

@staticmethod
def _new_col_visible(colname):
return True

def add_item(self, item):
"""
Add an item/row to the table.
Expand Down Expand Up @@ -2351,10 +2368,21 @@ def float_precision(column, item):
if self._qtable is None:
self._qtable = QTable([item])
else:
# NOTE: this does not support adding columns that did not exist in the first
# call to add_row since the last call to clear_table
# add any missing columns with a default value for all previous rows
for colname, value in item.items():
if colname in self._qtable.colnames:
continue
default_value = self.default_value_for_column(colname=colname,
value=value)
self._qtable.add_column(default_value, name=colname)

self._qtable.add_row(item)

missing_headers = [k for k in item.keys() if k not in self.headers_avail]
if len(missing_headers):
self.headers_avail = self.headers_avail + missing_headers
self.headers_visible = self.headers_visible + [m for m in missing_headers if self._new_col_visible(m)] # noqa

# clean data to show in the UI
self.items = self.items + [{k: json_safe(k, v) for k, v in item.items()}]

Expand Down

0 comments on commit b6dd6c1

Please sign in to comment.