From b800215742c63230dad3aff792e8295dff6fd0e4 Mon Sep 17 00:00:00 2001 From: Benjamin Johnson Date: Sun, 19 Jun 2022 09:46:34 -0400 Subject: [PATCH] allow use of list of filters or filterset. --- prospect/models/sedmodel.py | 2 +- tests/test_eline.py | 32 ++++++++++++++++++++------------ tests/test_predict.py | 1 - 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/prospect/models/sedmodel.py b/prospect/models/sedmodel.py index a3b04b66..2230fcbd 100644 --- a/prospect/models/sedmodel.py +++ b/prospect/models/sedmodel.py @@ -326,7 +326,7 @@ def nebline_photometry(self, filterset, elams=None, elums=None): # faster way to look up the transmission than the later loop flist = filterset.filters except(AttributeError): - flist = filters + flist = filterset for i, filt in enumerate(flist): # calculate transmission at line wavelengths trans = np.interp(elams, filt.wavelength, filt.transmission, diff --git a/tests/test_eline.py b/tests/test_eline.py index 9882e71e..4b3552ef 100644 --- a/tests/test_eline.py +++ b/tests/test_eline.py @@ -3,16 +3,22 @@ import numpy as np +import pytest + from sedpy import observate -from prospect import prospect_args from prospect.data import Photometry, Spectrum, from_oldstyle from prospect.models.templates import TemplateLibrary from prospect.models.sedmodel import SpecModel - from prospect.sources import CSPSpecBasis +@pytest.fixture +def get_sps(): + sps = CSPSpecBasis(zcontinuous=1) + return sps + + # test nebular line specification def test_eline_parsing(): model_pars = TemplateLibrary["parametric_sfh"] @@ -53,7 +59,7 @@ def test_eline_parsing(): assert model._fit_eline.sum() == (len(model._use_eline) - len(fix_lines)) -def test_nebline_phot_addition(): +def test_nebline_phot_addition(get_sps): fnames = [f"sdss_{b}0" for b in "ugriz"] filts = observate.load_filters(fnames) @@ -61,9 +67,10 @@ def test_nebline_phot_addition(): wavelength=np.linspace(3000, 9000, 1000), spectrum=np.ones(1000), unc=np.ones(1000)*0.1) - obslist = from_oldstyle(obs) + sdat, pdat = from_oldstyle(obs) + obslist = [sdat, pdat] - sps = CSPSpecBasis(zcontinuous=1) + sps = get_sps # Make emission lines more prominent zred = 1.0 @@ -85,14 +92,14 @@ def test_nebline_phot_addition(): (s2, p2), _ = m2.predict(m2.theta, obslist, sps) # make sure some of the lines were important - p1n = m1.nebline_photometry(obslist[-1].filterset) + p1n = m1.nebline_photometry(filts) assert np.any(p1n / p1[1] > 0.05) # make sure you got the same-ish answer assert np.all((np.abs(p1 - p2) / p1) < 1e-2) -def test_filtersets(): +def test_filtersets(get_sps): """This test no longer relevant..... """ fnames = [f"sdss_{b}0" for b in "ugriz"] @@ -102,9 +109,10 @@ def test_filtersets(): spectrum=np.ones(1000), unc=np.ones(1000)*0.1, filters=fnames) - obslist = from_oldstyle(obs) + sdat, pdat = from_oldstyle(obs) + obslist = [sdat, pdat] - sps = CSPSpecBasis(zcontinuous=1) + sps = get_sps # Make emission lines more prominent zred = 0.5 @@ -128,7 +136,7 @@ def test_filtersets(): # make sure some of the filters are affected by lines # ( nebular flux > 10% of total flux) if i == 1: - nebphot = model.nebline_photometry(flist) + nebphot = model.nebline_photometry(pdat.filterset) assert np.any(nebphot / pset > 0.1) # make sure photometry is consistent @@ -136,7 +144,7 @@ def test_filtersets(): # We always use filtersets now -def test_eline_implementation(): +def test_eline_implementation(get_sps): test_eline_parsing() @@ -156,7 +164,7 @@ def test_eline_implementation(): model_pars["zred"]["init"] = 4 model = SpecModel(model_pars) - sps = CSPSpecBasis(zcontinuous=1) + sps = get_sps # generate with all fixed lines added (spec, phot), mfrac = model.predict(model.theta, obslist, sps=sps) diff --git a/tests/test_predict.py b/tests/test_predict.py index 6fafd053..73ff5c16 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -47,7 +47,6 @@ def build_obs(multispec=True): return obslist -#@pytest.mark.skip(reason="not ready") def test_prediction_nodata(build_sps): sps = build_sps model = build_model(add_neb=True)