Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions specutils/io/registers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
A module containing the mechanics of the specutils io registry.
"""
import inspect
import os
import pathlib
import sys
Expand All @@ -16,6 +17,16 @@
log = logging.getLogger(__name__)


def _astropy_has_priorities():
"""
Check if astropy has support for loader priorities
"""
sig = inspect.signature(io_registry.register_reader)
if sig.parameters.get("priority") is not None:
return True
return False


def data_loader(label, identifier=None, dtype=Spectrum1D, extensions=None,
priority=0):
"""
Expand Down Expand Up @@ -52,7 +63,10 @@ def wrapper(*args, **kwargs):
return wrapper

def decorator(func):
io_registry.register_reader(label, dtype, func)
if _astropy_has_priorities():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this check really necessary? If a user-defined reader/writer has no priority, seems we can just default to the lowest priority instead of resorting to reflection?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Older versions of io_registry.register_reader won't accept the additional argument, so if we don't do this check, we'll get exceptions when trying to add a priority. Once all supported versions of astropy accept priorities, we can drop the check.

io_registry.register_reader(label, dtype, func, priority=priority)
else:
io_registry.register_reader(label, dtype, func)

if identifier is None:
# If the identifier is not defined, but the extensions are, create
Expand All @@ -78,17 +92,6 @@ def decorator(func):
# Include the file extensions as attributes on the function object
func.extensions = extensions

# Include priority on the loader function attribute
func.priority = priority

# Sort the io_registry based on priority
sorted_loaders = sorted(io_registry._readers.items(),
key=lambda item: getattr(item[1], 'priority', 0))

# Update the registry with the sorted dictionary
io_registry._readers.clear()
io_registry._readers.update(sorted_loaders)

log.debug("Successfully loaded reader \"{}\".".format(label))

# Automatically register a SpectrumList reader for any data_loader that
Expand All @@ -102,7 +105,14 @@ def load_spectrum_list(*args, **kwargs):
load_spectrum_list.extensions = extensions
load_spectrum_list.priority = priority

io_registry.register_reader(label, SpectrumList, load_spectrum_list)
if _astropy_has_priorities():
io_registry.register_reader(
label, SpectrumList, load_spectrum_list, priority=priority,
)
else:
io_registry.register_reader(
label, SpectrumList, load_spectrum_list,
)
io_registry.register_identifier(label, SpectrumList, id_func)
log.debug("Created SpectrumList reader for \"{}\".".format(label))

Expand All @@ -113,9 +123,12 @@ def wrapper(*args, **kwargs):
return decorator


def custom_writer(label, dtype=Spectrum1D):
def custom_writer(label, dtype=Spectrum1D, priority=0):
def decorator(func):
io_registry.register_writer(label, Spectrum1D, func)
if _astropy_has_priorities():
io_registry.register_writer(label, dtype, func, priority=priority)
else:
io_registry.register_writer(label, dtype, func)

@wraps(func)
def wrapper(*args, **kwargs):
Expand Down
49 changes: 48 additions & 1 deletion specutils/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
This module tests SpecUtils io routines
"""

from collections import Counter
from specutils.io.parsing_utils import generic_spectrum_from_table # or something like that
from astropy.io import registry
from astropy.table import Table
Expand All @@ -15,7 +16,8 @@
import warnings

from specutils import Spectrum1D, SpectrumList
from specutils.io import data_loader
from specutils.io import data_loader, custom_writer
from specutils.io.registers import _astropy_has_priorities


def test_generic_spectrum_from_table(recwarn):
Expand Down Expand Up @@ -156,3 +158,48 @@ def reader(*args, **kwargs):
# Clean up after ourselves
registry.unregister_reader(format_name, datatype)
registry.unregister_identifier(format_name, datatype)


@pytest.mark.xfail(
not _astropy_has_priorities(),
reason="Test requires priorities to be implemented in astropy",
raises=registry.IORegistryError,
)
def test_loader_uses_priority(tmpdir):
counter = Counter()
fname = str(tmpdir.join('good.txt'))

with open(fname, 'w') as ff:
ff.write('\n')

def identifier(origin, *args, **kwargs):
fname = args[0]
return 'good' in fname

@data_loader("test_counting_loader1", identifier=identifier, priority=1)
def counting_loader1(*args, **kwargs):
counter["test1"] += 1
wave = np.arange(1,1.1,0.01)*u.AA
return Spectrum1D(
spectral_axis=wave,
flux=np.ones(len(wave))*1.e-14*u.Jy,
)

@data_loader("test_counting_loader2", identifier=identifier, priority=2)
def counting_loader2(*args, **kwargs):
counter["test2"] += 1
wave = np.arange(1,1.1,0.01)*u.AA
return Spectrum1D(
spectral_axis=wave,
flux=np.ones(len(wave))*1.e-14*u.Jy,
)

Spectrum1D.read(fname)
assert counter["test2"] == 1
assert counter["test1"] == 0

for datatype in [Spectrum1D, SpectrumList]:
registry.unregister_reader("test_counting_loader1", datatype)
registry.unregister_identifier("test_counting_loader1", datatype)
registry.unregister_reader("test_counting_loader2", datatype)
registry.unregister_identifier("test_counting_loader2", datatype)