Skip to content

Commit 186fa22

Browse files
committed
Add test for loader priorities
1 parent 5885ed9 commit 186fa22

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

specutils/io/registers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,9 @@ def wrapper(*args, **kwargs):
126126
def custom_writer(label, dtype=Spectrum1D, priority=0):
127127
def decorator(func):
128128
if _astropy_has_priorities():
129-
io_registry.register_writer(
130-
label, Spectrum1D, func, priority=priority,
131-
)
129+
io_registry.register_writer(label, dtype, func, priority=priority)
132130
else:
133-
io_registry.register_writer(label, Spectrum1D, func)
131+
io_registry.register_writer(label, dtype, func)
134132

135133
@wraps(func)
136134
def wrapper(*args, **kwargs):

specutils/tests/test_io.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
This module tests SpecUtils io routines
55
"""
66

7+
from collections import Counter
78
from specutils.io.parsing_utils import generic_spectrum_from_table # or something like that
89
from astropy.io import registry
910
from astropy.table import Table
@@ -15,7 +16,8 @@
1516
import warnings
1617

1718
from specutils import Spectrum1D, SpectrumList
18-
from specutils.io import data_loader
19+
from specutils.io import data_loader, custom_writer
20+
from specutils.io.registers import _astropy_has_priorities
1921

2022

2123
def test_generic_spectrum_from_table(recwarn):
@@ -156,3 +158,46 @@ def reader(*args, **kwargs):
156158
# Clean up after ourselves
157159
registry.unregister_reader(format_name, datatype)
158160
registry.unregister_identifier(format_name, datatype)
161+
162+
163+
@pytest.mark.xfail(
164+
not _astropy_has_priorities(),
165+
reason="Test requires priorities to be implemented in astropy",
166+
raises=registry.IORegistryError,
167+
)
168+
def test_loader_uses_priority(tmpdir):
169+
counter = Counter()
170+
fname = str(tmpdir.join('good.txt'))
171+
172+
with open(fname, 'w') as ff:
173+
ff.write('\n')
174+
175+
def identifier(origin, *args, **kwargs):
176+
fname = args[0]
177+
return 'good' in fname
178+
179+
@data_loader("test_counting_loader1", identifier=identifier, priority=1)
180+
def counting_loader1(*args, **kwargs):
181+
counter["test1"] += 1
182+
return Spectrum1D(
183+
spectral_axis=np.arange(1,1.1,0.01)*u.AA,
184+
flux=np.ones(len(wave))*1.e-14*u.Jy,
185+
)
186+
187+
@data_loader("test_counting_loader2", identifier=identifier, priority=2)
188+
def counting_loader2(*args, **kwargs):
189+
counter["test2"] += 1
190+
return Spectrum1D(
191+
spectral_axis=np.arange(1,1.1,0.01)*u.AA,
192+
flux=np.ones(len(wave))*1.e-14*u.Jy,
193+
)
194+
195+
Spectrum1D.read(fname)
196+
assert counter["test2"] == 1
197+
assert counter["test1"] == 0
198+
199+
for datatype in [Spectrum1D, SpectrumList]:
200+
registry.unregister_reader("test_counting_loader1", datatype)
201+
registry.unregister_identifier("test_counting_loader1", datatype)
202+
registry.unregister_reader("test_counting_loader2", datatype)
203+
registry.unregister_identifier("test_counting_loader2", datatype)

0 commit comments

Comments
 (0)