Skip to content

Commit

Permalink
fix polynomial regularization bug (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
bd-j committed Dec 11, 2023
1 parent 30d2bab commit dd50479
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions prospect/models/sedmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,9 +659,8 @@ def spec_calibration(self, theta=None, obs=None, spec=None, **kwargs):
"""Implements a Chebyshev polynomial calibration model. This uses
least-squares to find the maximum-likelihood Chebyshev polynomial of a
certain order describing the ratio of the observed spectrum to the model
spectrum, conditional on all other parameters, using least squares. If
emission lines are being marginalized out, they are excluded from the
least-squares fit.
spectrum, conditional on all other parameters. If emission lines are
being marginalized out, they are excluded from the least-squares fit.
:returns cal:
A polynomial given by :math:`\sum_{m=0}^M a_{m} * T_m(x)`.
Expand All @@ -670,11 +669,10 @@ def spec_calibration(self, theta=None, obs=None, spec=None, **kwargs):
self.set_parameters(theta)

# norm = self.params.get('spec_norm', 1.0)
polyopt = ((self.params.get('polyorder', 0) > 0) &
order = np.squeeze(self.params.get('polyorder', 0))
polyopt = ((order > 0) &
(obs.get('spectrum', None) is not None))
if polyopt:
order = self.params['polyorder']

# generate mask
# remove region around emission lines if doing analytical marginalization
mask = obs.get('mask', np.ones_like(obs['wavelength'], dtype=bool)).copy()
Expand Down Expand Up @@ -1225,10 +1223,10 @@ def spec_calibration(self, theta=None, obs=None, **kwargs):
self.set_parameters(theta)

norm = self.params.get('spec_norm', 1.0)
polyopt = ((self.params.get('polyorder', 0) > 0) &
order = np.squeeze(self.params.get('polyorder', 0))
polyopt = ((order > 0) &
(obs.get('spectrum', None) is not None))
if polyopt:
order = self.params['polyorder']
mask = obs.get('mask', slice(None))
# map unmasked wavelengths to the interval -1, 1
# masked wavelengths may have x>1, x<-1
Expand Down

0 comments on commit dd50479

Please sign in to comment.