From 6b648a62f9d9b4b1a0f8fe6efdbbaf8a04cab583 Mon Sep 17 00:00:00 2001 From: James Tocknell Date: Wed, 13 Jan 2021 13:36:39 +1100 Subject: [PATCH 1/2] Add support for astropy registry priorities --- specutils/io/registers.py | 45 ++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/specutils/io/registers.py b/specutils/io/registers.py index f7f942d87..40f720d36 100644 --- a/specutils/io/registers.py +++ b/specutils/io/registers.py @@ -1,6 +1,7 @@ """ A module containing the mechanics of the specutils io registry. """ +import inspect import os import pathlib import sys @@ -16,6 +17,16 @@ log = logging.getLogger(__name__) +def _astropy_has_priorities(): + """ + Check if astropy has support for loader priorities + """ + sig = inspect.signature(io_registry.register_reader) + if sig.parameters.get("priority") is not None: + return True + return False + + def data_loader(label, identifier=None, dtype=Spectrum1D, extensions=None, priority=0): """ @@ -52,7 +63,10 @@ def wrapper(*args, **kwargs): return wrapper def decorator(func): - io_registry.register_reader(label, dtype, func) + if _astropy_has_priorities(): + io_registry.register_reader(label, dtype, func, priority=priority) + else: + io_registry.register_reader(label, dtype, func) if identifier is None: # If the identifier is not defined, but the extensions are, create @@ -78,17 +92,6 @@ def decorator(func): # Include the file extensions as attributes on the function object func.extensions = extensions - # Include priority on the loader function attribute - func.priority = priority - - # Sort the io_registry based on priority - sorted_loaders = sorted(io_registry._readers.items(), - key=lambda item: getattr(item[1], 'priority', 0)) - - # Update the registry with the sorted dictionary - io_registry._readers.clear() - io_registry._readers.update(sorted_loaders) - log.debug("Successfully loaded reader \"{}\".".format(label)) # Automatically register a SpectrumList reader for any data_loader that @@ -102,7 +105,14 @@ def load_spectrum_list(*args, **kwargs): load_spectrum_list.extensions = extensions load_spectrum_list.priority = priority - io_registry.register_reader(label, SpectrumList, load_spectrum_list) + if _astropy_has_priorities(): + io_registry.register_reader( + label, SpectrumList, load_spectrum_list, priority=priority, + ) + else: + io_registry.register_reader( + label, SpectrumList, load_spectrum_list, + ) io_registry.register_identifier(label, SpectrumList, id_func) log.debug("Created SpectrumList reader for \"{}\".".format(label)) @@ -113,9 +123,14 @@ def wrapper(*args, **kwargs): return decorator -def custom_writer(label, dtype=Spectrum1D): +def custom_writer(label, dtype=Spectrum1D, priority=0): def decorator(func): - io_registry.register_writer(label, Spectrum1D, func) + if _astropy_has_priorities(): + io_registry.register_writer( + label, Spectrum1D, func, priority=priority, + ) + else: + io_registry.register_writer(label, Spectrum1D, func) @wraps(func) def wrapper(*args, **kwargs): From 2ff82ffc6329a0a27805ad9afd649bed87c3a605 Mon Sep 17 00:00:00 2001 From: James Tocknell Date: Tue, 15 Jun 2021 21:31:44 +1000 Subject: [PATCH 2/2] Add test for loader priorities --- specutils/io/registers.py | 6 ++--- specutils/tests/test_io.py | 49 +++++++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/specutils/io/registers.py b/specutils/io/registers.py index 40f720d36..3559d0745 100644 --- a/specutils/io/registers.py +++ b/specutils/io/registers.py @@ -126,11 +126,9 @@ def wrapper(*args, **kwargs): def custom_writer(label, dtype=Spectrum1D, priority=0): def decorator(func): if _astropy_has_priorities(): - io_registry.register_writer( - label, Spectrum1D, func, priority=priority, - ) + io_registry.register_writer(label, dtype, func, priority=priority) else: - io_registry.register_writer(label, Spectrum1D, func) + io_registry.register_writer(label, dtype, func) @wraps(func) def wrapper(*args, **kwargs): diff --git a/specutils/tests/test_io.py b/specutils/tests/test_io.py index 7704ea51e..0ec3559b3 100644 --- a/specutils/tests/test_io.py +++ b/specutils/tests/test_io.py @@ -4,6 +4,7 @@ This module tests SpecUtils io routines """ +from collections import Counter from specutils.io.parsing_utils import generic_spectrum_from_table # or something like that from astropy.io import registry from astropy.table import Table @@ -15,7 +16,8 @@ import warnings from specutils import Spectrum1D, SpectrumList -from specutils.io import data_loader +from specutils.io import data_loader, custom_writer +from specutils.io.registers import _astropy_has_priorities def test_generic_spectrum_from_table(recwarn): @@ -156,3 +158,48 @@ def reader(*args, **kwargs): # Clean up after ourselves registry.unregister_reader(format_name, datatype) registry.unregister_identifier(format_name, datatype) + + +@pytest.mark.xfail( + not _astropy_has_priorities(), + reason="Test requires priorities to be implemented in astropy", + raises=registry.IORegistryError, +) +def test_loader_uses_priority(tmpdir): + counter = Counter() + fname = str(tmpdir.join('good.txt')) + + with open(fname, 'w') as ff: + ff.write('\n') + + def identifier(origin, *args, **kwargs): + fname = args[0] + return 'good' in fname + + @data_loader("test_counting_loader1", identifier=identifier, priority=1) + def counting_loader1(*args, **kwargs): + counter["test1"] += 1 + wave = np.arange(1,1.1,0.01)*u.AA + return Spectrum1D( + spectral_axis=wave, + flux=np.ones(len(wave))*1.e-14*u.Jy, + ) + + @data_loader("test_counting_loader2", identifier=identifier, priority=2) + def counting_loader2(*args, **kwargs): + counter["test2"] += 1 + wave = np.arange(1,1.1,0.01)*u.AA + return Spectrum1D( + spectral_axis=wave, + flux=np.ones(len(wave))*1.e-14*u.Jy, + ) + + Spectrum1D.read(fname) + assert counter["test2"] == 1 + assert counter["test1"] == 0 + + for datatype in [Spectrum1D, SpectrumList]: + registry.unregister_reader("test_counting_loader1", datatype) + registry.unregister_identifier("test_counting_loader1", datatype) + registry.unregister_reader("test_counting_loader2", datatype) + registry.unregister_identifier("test_counting_loader2", datatype)