Skip to content

Commit

Permalink
fix spectral extraction previews on unit change (#3157)
Browse files Browse the repository at this point in the history
  • Loading branch information
kecnry authored Oct 7, 2024
1 parent a9cd7b5 commit 770dfd1
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 23 deletions.
7 changes: 6 additions & 1 deletion jdaviz/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,7 +1292,12 @@ def merge_func(spectral_region): # noop
return self._get_multi_mask_subset_definition(subset_state)

def _get_display_unit(self, axis):
if self._jdaviz_helper is None or self._jdaviz_helper.plugins.get('Unit Conversion') is None: # noqa
if self._jdaviz_helper is None:
# cannot access either the plugin or the spectrum viewer.
# Plugins that access the unit at this point will need to
# detect that they are set to unitless and attempt again later.
return ''
elif self._jdaviz_helper.plugins.get('Unit Conversion') is None: # noqa
# fallback on native units (unit conversion is not enabled)
if axis == 'spectral':
sv = self.get_viewer(self._jdaviz_helper._default_spectrum_viewer_reference_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from jdaviz.core.validunits import check_if_unit_is_per_solid_angle
from jdaviz.configs.cubeviz.plugins.parsers import _return_spectrum_with_correct_units
from jdaviz.configs.cubeviz.plugins.viewers import WithSliceIndicator
from jdaviz.utils import _eqv_pixar_sr


__all__ = ['SpectralExtraction']
Expand Down Expand Up @@ -109,8 +110,8 @@ class SpectralExtraction(PluginTemplateMixin, ApertureSubsetSelectMixin,

results_units = Unicode().tag(sync=True)
spectrum_y_units = Unicode().tag(sync=True)
flux_unit = Unicode().tag(sync=True)
sb_unit = Unicode().tag(sync=True)
flux_units = Unicode().tag(sync=True)
sb_units = Unicode().tag(sync=True)

aperture_method_items = List().tag(sync=True)
aperture_method_selected = Unicode('Center').tag(sync=True)
Expand Down Expand Up @@ -316,31 +317,37 @@ def _update_mark_scale(self, *args):
else:
self.background.scale_factor = self.slice_spectral_value/self.reference_spectral_value

def _on_global_display_unit_changed(self, msg={}):

if msg.axis == 'spectral_y':
self.spectrum_y_units = str(msg.unit)

# a 'flux' and 'sb' message should be recieved back to back from
# the unit conversion plugin, so don't need to sync them immediatley
# within each message recieved
def _on_global_display_unit_changed(self, msg=None):
if msg is None:
self.flux_units = str(self.app._get_display_unit('flux'))
self.sb_units = str(self.app._get_display_unit('sb'))
self.spectrum_y_units = str(self.app._get_display_unit('spectral_y'))
elif msg.axis == 'flux':
self.flux_unit = str(msg.unit)
self.flux_units = str(msg.unit)
elif msg.axis == 'sb':
self.sb_unit = str(msg.unit)
self.sb_units = str(msg.unit)
elif msg.axis == 'spectral_y':
self.spectrum_y_units = str(msg.unit)
# no need to update results_units as separate messages will have been
# sent by unit conversion for flux and/or sb.
# updates to spectrum_y_units will trigger updating the extraction preview
return
else:
# ignore
return

# and set results_units, which depends on function selected
self._update_units_on_function_selection()
# update results_units based on flux_units, sb_units, and currently selected function
self._update_results_units()

@observe('function_selected')
def _update_units_on_function_selection(self, *args):

def _update_results_units(self, *args):
# NOTE this is also called by _on_global_display_unit_changed
# after flux_units and/or sb_units is set.
# results_units is ONLY used for the warning in the UI, so does not
# need to trigger an update to the preview
if self.function_selected.lower() == 'sum':
self.results_units = self.flux_unit
self.results_units = self.flux_units
else:
self.results_units = self.sb_unit
self.results_units = self.sb_units

@observe('function_selected', 'aperture_method_selected')
def _update_aperture_method_on_function_change(self, *args):
Expand Down Expand Up @@ -550,10 +557,12 @@ def _return_extracted(self, cube, wcs, collapsed_nddata):
return collapsed_spec

def _preview_x_from_extracted(self, extracted):
return extracted.spectral_axis.value
return extracted.spectral_axis

def _preview_y_from_extracted(self, extracted):
return extracted.flux.value
# TODO: use extracted.meta.get('PIXAR_SR') once populated
return extracted.flux.to(self.spectrum_y_units,
equivalencies=_eqv_pixar_sr(self.dataset.selected_obj.meta.get('PIXAR_SR', 1.0))) # noqa:

@with_spinner()
def extract(self, return_bg=False, add_data=True, **kwargs):
Expand Down Expand Up @@ -737,8 +746,15 @@ def _clear_marks(self):
'wavelength_dependent', 'bg_wavelength_dependent', 'reference_spectral_value',
'function_selected',
'aperture_method_selected',
'spectrum_y_units',
'previews_temp_disabled')
def _live_update_marks(self, event={}):
if self.spectrum_y_units == '':
# ensure that units are populated
# which in turn will make a call back here
# from the observe on spectrum_y_units
self._on_global_display_unit_changed(None)
return
self._update_marks(event)

@skip_if_not_tray_instance()
Expand All @@ -753,7 +769,7 @@ def _update_marks(self, event={}):
self.marks['bg_extract'].visible = self.active_step == 'bg' and self.bg_selected != self.background.default_text # noqa
self.marks['extract'].visible = True

# _live_update will skip if no updates since last active
# _live_update_extract will skip if no updates since last active
self._live_update_extract(event)

@skip_if_no_updates_since_last_active()
Expand Down
16 changes: 16 additions & 0 deletions jdaviz/core/marks.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,24 @@ def hub(self):
return self.viewer.hub

def update_xy(self, x, y):
# If x and y are not in the previous units, they should be provided as quantities
if hasattr(x, 'value'):
xunit = x.unit
x = x.value
else:
xunit = None
self.x = np.asarray(x)
if xunit is not None:
self.xunit = u.Unit(xunit)

if hasattr(y, 'value'):
yunit = y.unit
y = y.value
else:
yunit = None
self.y = np.asarray(y)
if yunit is not None:
self.yunit = u.Unit(yunit)

def append_xy(self, x, y):
self.x = np.append(self.x, x)
Expand Down

0 comments on commit 770dfd1

Please sign in to comment.