diff --git a/ci/environment.yml b/ci/environment.yml index d70d3c28..bd0b9458 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -6,6 +6,7 @@ dependencies: - cftime - codecov - dask >=2024.12 + - esmvalcore >=2.0.0 - fastprogress >=1.0.0 - flaky >=3.8.0 - fsspec >=2024.12 diff --git a/intake_esm/__init__.py b/intake_esm/__init__.py index 1c8a03ac..e7e3d83d 100644 --- a/intake_esm/__init__.py +++ b/intake_esm/__init__.py @@ -4,11 +4,43 @@ # Import intake first to avoid circular imports during discovery. import intake +import importlib from intake_esm import tutorial from intake_esm.core import esm_datastore from intake_esm.derived import DerivedVariableRegistry, default_registry from intake_esm.utils import set_options, show_versions +from intake_esm._imports import _to_opt_import_flag, _from_opt_import_flag +from intake_esm import _imports as _import_module from intake_esm._version import __version__ + +import_flags = [_to_opt_import_flag(name) for name in _import_module._optional_imports] + +__all__ = [ + 'esm_datastore', + 'DerivedVariableRegistry', + 'default_registry', + 'set_options', + 'show_versions', + 'tutorial', + '__version__', +] + import_flags + + +def __getattr__(attr: str) -> object: + """ + Lazy load optional imports. + """ + + if attr in (gl := globals()): + return gl[attr] + + try: + return getattr(_import_module, attr) + except AttributeError: + raise AttributeError( + f"Module '{__name__}' has no attribute '{attr}'. " + f'Did you mean one of {", ".join(import_flags)}?' 
+        )
diff --git a/intake_esm/_imports.py b/intake_esm/_imports.py
new file mode 100644
index 00000000..155bfd76
--- /dev/null
+++ b/intake_esm/_imports.py
@@ -0,0 +1,41 @@
+import importlib.util
+
+_optional_imports: dict[str, bool | None] = {'esmvalcore': None}
+
+
+def _to_opt_import_flag(name: str) -> str:
+    """Dynamically create import flags for optional imports."""
+    return f'_{name.upper()}_AVAILABLE'
+
+
+def _from_opt_import_flag(name: str) -> str:
+    """Dynamically retrieve the optional import name from its flag."""
+    if name.startswith('_') and name.endswith('_AVAILABLE'):
+        return name[1:-10].lower()
+    raise ValueError(
+        f"Invalid optional import flag '{name}'. Expected format: '_<NAME>_AVAILABLE'."
+    )
+
+
+def __getattr__(attr: str) -> object:
+    """
+    Lazy load optional imports.
+    """
+
+    if attr in (gl := globals()):
+        return gl[attr]
+
+    import_flags = [_to_opt_import_flag(name) for name in _optional_imports]
+
+    if attr in import_flags:
+        import_name = _from_opt_import_flag(attr)
+        if _optional_imports.get(import_name, None) is None:
+            _optional_imports[import_name] = bool(importlib.util.find_spec(import_name))
+            return _optional_imports[import_name]
+        else:
+            return _optional_imports[import_name]
+
+    raise AttributeError(
+        f"Module '{__name__}' has no attribute '{attr}'. "
+        f'Did you mean one of {", ".join(import_flags)}?'
+ ) diff --git a/intake_esm/cat.py b/intake_esm/cat.py index 8f9e5c83..084ef2cc 100644 --- a/intake_esm/cat.py +++ b/intake_esm/cat.py @@ -114,6 +114,7 @@ class ESMCatalogModel(pydantic.BaseModel): id: str = '' catalog_dict: list[dict] | None = None catalog_file: pydantic.StrictStr | None = None + fhandle: pydantic.StrictStr | None = None description: pydantic.StrictStr | None = None title: pydantic.StrictStr | None = None last_updated: datetime.datetime | datetime.date | None = None @@ -269,6 +270,7 @@ def load( df=pl.DataFrame(cat.catalog_dict).to_pandas(), ) + cat.fhandle = json_file cat._cast_agg_columns_with_iterables() return cat @@ -496,6 +498,37 @@ def validate_query(cls, model): model.query = _query return model + def _extend_search_history( + cls, search_hist: list[dict[str, typing.Any]] + ) -> list[dict[str, typing.Any]]: + """ + Extend the search history with the current query. Note this doesn't yet + handle cases where we have set `require_all_on`. + + Parameters + ---------- + search_hist : list[dict] + The current search history. + query : QueryModel + The current query to be added to the search history. + + Returns + ------- + list[dict[str, typing.Any]] + The updated search history. 
+ """ + + _query = cls.query + + if not _query: + search_hist.append({}) + return search_hist + + for colname, search_terms in _query.items(): + search_hist.append({colname: search_terms}) + + return search_hist + class FramesModel(pydantic.BaseModel): """A Pydantic model to represent our collection of dataframes - pandas, polars, diff --git a/intake_esm/core.py b/intake_esm/core.py index ce71e094..2608ed93 100644 --- a/intake_esm/core.py +++ b/intake_esm/core.py @@ -5,6 +5,10 @@ import warnings from copy import deepcopy +if typing.TYPE_CHECKING: + import esmvalcore + import esmvalcore.dataset + import dask import packaging.version import xarray as xr @@ -24,7 +28,8 @@ from fastprogress.fastprogress import progress_bar from intake.catalog import Catalog -from .cat import ESMCatalogModel +from ._imports import _ESMVALCORE_AVAILABLE +from .cat import ESMCatalogModel, QueryModel from .derived import DerivedVariableRegistry, default_registry from .source import ESMDataSource from .utils import MinimalExploder @@ -397,6 +402,16 @@ def __dir__(self) -> list[str]: def _ipython_key_completions_(self): return self.__dir__() + @property + def search_history(self) -> list[dict[str, typing.Any]]: + """Return the search history for the catalog.""" + + try: + return self._search_history + except AttributeError: + self._search_history: list[dict[str, typing.Any]] = [] + return self._search_history + @pydantic.validate_call def search( self, @@ -458,6 +473,14 @@ def search( 4 landCoverFrac """ + _search_hist = ( + query + if isinstance(query, QueryModel) + else QueryModel( + query=query, require_all_on=require_all_on, columns=self.df.columns.tolist() + )._extend_search_history(self.search_history) + ) + # step 1: Search in the base/main catalog esmcat_results = self.esmcat.search(require_all_on=require_all_on, query=query) @@ -507,6 +530,8 @@ def search( cat.derivedcat._registry.update(derived_cat_subset) else: cat.derivedcat = self.derivedcat + + cat._search_history = 
_search_hist return cat @pydantic.validate_call @@ -893,6 +918,74 @@ def to_dask(self, **kwargs) -> xr.Dataset: _, ds = res.popitem() return ds + def to_esmvalcore( + self, + cmorizer: typing.Any | None = None, + **kwargs, + ) -> 'esmvalcore.dataset.Dataset': + """ + Convert result to an ESMValCore Dataset. + + This is only possible if the search returned exactly one result. + + Parameters + ---------- + facet_map: dict[FacetValue, str] + Mapping of ESMValCore Dataset facets to their corresponding esm_datastore + attributes. For example, the mapping for a dataset containing keys + 'activity_id', 'source_id', 'member_id', 'experiment_id' would look like: + ```python + facets = { + "activity": "activity_id", + "dataset": "source_id", + "ensemble": "member_id", + "exp": "experiment_id", + "grid": "grid_label", + }, + ``` + cmorize: Any, optional + CMORizer to use in order to CMORize the datastore search results for + the ESMValCore Dataset. Presumably this will be a callable? If not set, + no CMORization will be done. + kwargs: dict + TBC. + """ + if not _ESMVALCORE_AVAILABLE: + raise ImportError( + '`to_esmvalcore()` requires the esmvalcore package to be installed. ' + 'To proceed please install esmvalcore using: ' + ' `python -m pip install esmvalcore` or `conda install -c conda-forge esmvalcore`.' + ) + + if len(self) != 1: # quick check to fail more quickly if there are many results + raise ValueError( + f'Expected exactly one dataset. Received {len(self)} datasets. Please refine your search.' 
+ ) + + # Use esmvalcore to load the intake configuration & work out how we + # need to map our facets + + from esmvalcore.config._intake import _read_facets, load_intake_config + from esmvalcore.data import merge_intake_search_history as merge_search_history + from esmvalcore.dataset import Dataset + + facet_map, project = _read_facets(load_intake_config(), self.esmcat.fhandle) + + search = merge_search_history(self.search_history) + + facets = {k: search.get(v) for k, v in facet_map.items()} + facets = {k: v for k, v in facets.items() if v is not None} + + facets.pop('version', None) # If there's a version, chuck it + facets['project'] = project + + ds = Dataset(**facets) + + ds.files = self.unique().path + ds.augment_facets() + + return ds + def _create_derived_variables(self, datasets, skip_on_error): if len(self.derivedcat) > 0: datasets = self.derivedcat.update_datasets( diff --git a/tests/conftest.py b/tests/conftest.py index 6e45a660..b506712f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,8 @@ import pytest +import intake_esm + here = os.path.abspath(os.path.dirname(__file__)) @@ -13,3 +15,14 @@ def sample_cmip6(): @pytest.fixture def sample_bad_input(): return os.path.join(here, 'sample-catalogs/bad.json') + + +@pytest.fixture +def cleanup_init(): + """ + This resets the _optional_imports dictionary in intake_esm to it's default + state before & after tests that use it so we can test lazy loading and whatnot + """ + intake_esm._imports._optional_imports = {'esmvalcore': None} + yield + intake_esm._imports._optional_imports = {'esmvalcore': None} diff --git a/tests/test_core.py b/tests/test_core.py index 31469228..57b62c32 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -258,10 +258,55 @@ def test_catalog_contains(): ) def test_catalog_search(path, query, expected_size): cat = intake.open_esm_datastore(path) + assert cat.search_history == [] new_cat = cat.search(**query) assert len(new_cat) == expected_size 
+@pytest.mark.parametrize( + 'path, query, expected', + [ + (cdf_cat_sample_cesmle, {'experiment': 'CTRL'}, [{'experiment': ['CTRL']}]), + (cdf_cat_sample_cesmle, {'experiment': ['CTRL', '20C']}, [{'experiment': ['CTRL', '20C']}]), + (cdf_cat_sample_cesmle, {}, [{}]), + ( + cdf_cat_sample_cesmle, + {'variable': 'SHF', 'time_range': ['200601-210012']}, + [{'variable': ['SHF']}, {'time_range': ['200601-210012']}], + ), + ], +) +def test_catalog_search_history(path, query, expected): + cat = intake.open_esm_datastore(path) + assert cat.search_history == [] + new_cat = cat.search(**query) + assert new_cat.search_history == expected + + +@pytest.mark.parametrize( + 'path, queries, expected', + [ + (cdf_cat_sample_cesmle, [{'experiment': 'CTRL'}, {}], [{'experiment': ['CTRL']}, {}]), + ( + cdf_cat_sample_cesmle, + [{'variable': 'SHF'}, {'time_range': ['200601-210012']}], + [{'variable': ['SHF']}, {'time_range': ['200601-210012']}], + ), + ( + cdf_cat_sample_cesmle, + [{'experiment': ['CTRL', '20C']}, {'variable': 'SHF'}], + [{'experiment': ['CTRL', '20C']}, {'variable': ['SHF']}], + ), + ], +) +def test_catalog_search_history_sequential(path, queries, expected): + cat = intake.open_esm_datastore(path) + assert cat.search_history == [] + q1, q2 = queries + new_cat = cat.search(**q1).search(**q2) + assert new_cat.search_history == expected + + @pytest.mark.parametrize( 'path, columns_with_iterables, query, expected_size', [ @@ -704,3 +749,33 @@ def test__get_threaded(mock_get_env, threaded, ITK_ESM_THREADING, expected): intake_esm.core._get_threaded(threaded) else: assert intake_esm.core._get_threaded(threaded) == expected + + +@mock.patch('intake_esm.core._ESMVALCORE_AVAILABLE', False) +def test_to_esmvalcore_unavailable(): + cat = intake.open_esm_datastore(zarr_cat_pangeo_cmip6) + cat_sub = cat.search( + **dict( + variable_id=['pr'], + experiment_id='ssp370', + activity_id='AerChemMIP', + source_id='BCC-ESM1', + table_id='Amon', + grid_label='gn', + ) + ) + with 
pytest.raises(ImportError, match=r'`to_esmvalcore\(\)` requires the esmvalcore package'): + _ = cat_sub.to_esmvalcore( + search=dict( + variable_id=['pr'], + experiment_id='ssp370', + activity_id='AerChemMIP', + source_id='BCC-ESM1', + table_id='Amon', + grid_label='gn', + ), + xarray_open_kwargs={ + 'consolidated': True, + 'backend_kwargs': {'storage_options': {'token': 'anon'}}, + }, + ) diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 00000000..ecceb26f --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,104 @@ +from unittest import mock + +import pytest + + +def test__all__(cleanup_init): + import intake_esm + + assert intake_esm.__all__ == [ + 'esm_datastore', + 'DerivedVariableRegistry', + 'default_registry', + 'set_options', + 'show_versions', + 'tutorial', + '__version__', + '_ESMVALCORE_AVAILABLE', + ] + + +def test__to_optional_import_flag(cleanup_init): + from intake_esm import _to_opt_import_flag + + assert _to_opt_import_flag('esmvalcore') == '_ESMVALCORE_AVAILABLE' + # This looks stupid but we need to be careful re. underscores + assert _to_opt_import_flag('intake_esm') == '_INTAKE_ESM_AVAILABLE' + + +def test__from_optional_import_flag(cleanup_init): + from intake_esm import _from_opt_import_flag + + assert _from_opt_import_flag('_ESMVALCORE_AVAILABLE') == 'esmvalcore' + # This looks stupid but we need to be careful re. 
underscores + assert _from_opt_import_flag('_INTAKE_ESM_AVAILABLE') == 'intake_esm' + + +@pytest.mark.parametrize( + 'str', + [ + '_ESMVALCORE_AVAILABLE', + '_INTAKE_ESM_AVAILABLE', + ], +) +def test__opt_import_flags(cleanup_init, str): + from intake_esm import _from_opt_import_flag, _to_opt_import_flag + + assert _to_opt_import_flag(_from_opt_import_flag(str)) == str + + +@pytest.mark.parametrize( + 'str', + [ + '_INVALID_FLAG', + 'invalid_flag', + 'INVALID_FLAG_AVAILABLE', + ], +) +def test__opt_import_flags_invalid(cleanup_init, str): + from intake_esm import _from_opt_import_flag + + with pytest.raises(ValueError): + _from_opt_import_flag(str) + + +@pytest.mark.parametrize( + 'str', + [ + 'esmvalcore', + 'intake_esm', + ], +) +def test_rev_opt_import_flags(cleanup_init, str): + from intake_esm import _from_opt_import_flag, _to_opt_import_flag + + assert _from_opt_import_flag(_to_opt_import_flag(str)) == str + + +def test_getattr_random_attr_fail(cleanup_init): + import intake_esm + + with pytest.raises(AttributeError, match="Module 'intake_esm' has no attribute 'random_attr'"): + _ = intake_esm.random_attr + + +@mock.patch('importlib.util.find_spec', return_value=False) +def test_getattr_optional_import(mock_fnd_spec, cleanup_init): + import intake_esm + + assert intake_esm._optional_imports == {'esmvalcore': None} + + assert intake_esm._ESMVALCORE_AVAILABLE is False + assert intake_esm._optional_imports == {'esmvalcore': False} + + +@mock.patch('importlib.util.find_spec', return_value=True) +def test_getattr_caching(mock_find_spec, cleanup_init): + import intake_esm + + # Simulate the first call to find_spec + mock_find_spec.return_value = True + assert intake_esm._ESMVALCORE_AVAILABLE is True + assert intake_esm._ESMVALCORE_AVAILABLE is True + + mock_find_spec.assert_called_once_with('esmvalcore')