|
4 | 4 | This module tests SpecUtils io routines |
5 | 5 | """ |
6 | 6 |
|
| 7 | +from collections import Counter |
7 | 8 | from specutils.io.parsing_utils import generic_spectrum_from_table # or something like that |
8 | 9 | from astropy.io import registry |
9 | 10 | from astropy.table import Table |
|
15 | 16 | import warnings |
16 | 17 |
|
17 | 18 | from specutils import Spectrum1D, SpectrumList |
18 | | -from specutils.io import data_loader |
| 19 | +from specutils.io import data_loader, custom_writer |
| 20 | +from specutils.io.registers import _astropy_has_priorities |
19 | 21 |
|
20 | 22 |
|
21 | 23 | def test_generic_spectrum_from_table(recwarn): |
@@ -156,3 +158,46 @@ def reader(*args, **kwargs): |
156 | 158 | # Clean up after ourselves |
157 | 159 | registry.unregister_reader(format_name, datatype) |
158 | 160 | registry.unregister_identifier(format_name, datatype) |
| 161 | + |
| 162 | + |
| 163 | +@pytest.mark.xfail( |
| 164 | + not _astropy_has_priorities(), |
| 165 | + reason="Test requires priorities to be implemented in astropy", |
| 166 | + raises=registry.IORegistryError, |
| 167 | +) |
| 168 | +def test_loader_uses_priority(tmpdir): |
| 169 | + counter = Counter() |
| 170 | + fname = str(tmpdir.join('good.txt')) |
| 171 | + |
| 172 | + with open(fname, 'w') as ff: |
| 173 | + ff.write('\n') |
| 174 | + |
| 175 | + def identifier(origin, *args, **kwargs): |
| 176 | + fname = args[0] |
| 177 | + return 'good' in fname |
| 178 | + |
| 179 | + @data_loader("test_counting_loader1", identifier=identifier, priority=1) |
| 180 | + def counting_loader1(*args, **kwargs): |
| 181 | + counter["test1"] += 1 |
| 182 | + return Spectrum1D( |
| 183 | + spectral_axis=np.arange(1,1.1,0.01)*u.AA, |
| 184 | + flux=np.ones(len(wave))*1.e-14*u.Jy, |
| 185 | + ) |
| 186 | + |
| 187 | + @data_loader("test_counting_loader2", identifier=identifier, priority=2) |
| 188 | + def counting_loader2(*args, **kwargs): |
| 189 | + counter["test2"] += 1 |
| 190 | + return Spectrum1D( |
| 191 | + spectral_axis=np.arange(1,1.1,0.01)*u.AA, |
| 192 | + flux=np.ones(len(wave))*1.e-14*u.Jy, |
| 193 | + ) |
| 194 | + |
| 195 | + Spectrum1D.read(fname) |
| 196 | + assert counter["test2"] == 1 |
| 197 | + assert counter["test1"] == 0 |
| 198 | + |
| 199 | + for datatype in [Spectrum1D, SpectrumList]: |
| 200 | + registry.unregister_reader("test_counting_loader1", datatype) |
| 201 | + registry.unregister_identifier("test_counting_loader1", datatype) |
| 202 | + registry.unregister_reader("test_counting_loader2", datatype) |
| 203 | + registry.unregister_identifier("test_counting_loader2", datatype) |
0 commit comments