From 7dcf9acb6e0a5b127bf4eb8401c2600575ae1511 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 16 Jun 2024 21:38:20 +0100 Subject: [PATCH 01/40] Create retrieval.py --- skpro/utils/retrieval.py | 96 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 skpro/utils/retrieval.py diff --git a/skpro/utils/retrieval.py b/skpro/utils/retrieval.py new file mode 100644 index 000000000..b9c7902cb --- /dev/null +++ b/skpro/utils/retrieval.py @@ -0,0 +1,96 @@ +import importlib +import inspect +import pkgutil +from functools import lru_cache + +EXCLUDE_MODULES_STARTING_WITH = ("all", "test", "contrib") + + +def _all_functions(module_name): + """Get all functions from a module, including submodules. + + Excludes modules starting with 'all' or 'test'. + + Parameters + ---------- + module_name : str + Name of the module. + + Returns + ------- + functions_list : list + List of tuples (function_name: str, function_object: function). + """ + # copy to avoid modifying the cache + return _all_cond(module_name, inspect.isfunction).copy() + + +def _all_classes(module_name): + """Get all classes from a module, including submodules. + + Excludes modules starting with 'all' or 'test'. + + Parameters + ---------- + module_name : str + Name of the module. + + Returns + ------- + classes_list : list + List of tuples (class_name: str, class_ref: class). + """ + # copy to avoid modifying the cache + return _all_cond(module_name, inspect.isclass).copy() + + +@lru_cache +def _all_cond(module_name, cond): + """Get all objects from a module satisfying a condition. + + The condition should be a hashable callable, + of signature ``condition(obj) -> bool``. + + Excludes modules starting with 'all' or 'test'. + + Parameters + ---------- + module_name : str + Name of the module. + cond : callable + Condition to satisfy. + Signature: ``condition(obj) -> bool``, + passed as predicate to ``inspect.getmembers``. + + Returns + ------- + functions_list : list + List of tuples (function_name, function_object). + """ + # Import the package + package = importlib.import_module(module_name) + + # Initialize an empty list to hold all objects + obj_list = [] + + # Walk through the package's modules + package_path = package.__path__[0] + for _, modname, _ in pkgutil.walk_packages( + path=[package_path], prefix=package.__name__ + "." 
+ ): + # Skip modules starting with 'all' or 'test' + if modname.split(".")[-1].startswith(EXCLUDE_MODULES_STARTING_WITH): + continue + + # Import the module + module = importlib.import_module(modname) + + # Get all objects from the module + for name, obj in inspect.getmembers(module, cond): + # if object is imported from another module, skip it + if obj.__module__ != module.__name__: + continue + # add the object to the list + obj_list.append((name, obj)) + + return obj_list From 3657d940063a090b2e7bb9edc58c8a309868807d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Jun 2024 16:27:07 +0100 Subject: [PATCH 02/40] base cls --- skpro/datatypes/_base/__init__.py | 5 + skpro/datatypes/_base/_base.py | 201 ++++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+) create mode 100644 skpro/datatypes/_base/__init__.py create mode 100644 skpro/datatypes/_base/_base.py diff --git a/skpro/datatypes/_base/__init__.py b/skpro/datatypes/_base/__init__.py new file mode 100644 index 000000000..e04eb34d9 --- /dev/null +++ b/skpro/datatypes/_base/__init__.py @@ -0,0 +1,5 @@ +"""Base module for datatypes.""" + +from sktime.datatypes._base._base import BaseConverter, BaseDatatype + +__all__ = ["BaseConverter", "BaseDatatype"] diff --git a/skpro/datatypes/_base/_base.py b/skpro/datatypes/_base/_base.py new file mode 100644 index 000000000..fe663c370 --- /dev/null +++ b/skpro/datatypes/_base/_base.py @@ -0,0 +1,201 @@ +# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +"""Base class for data types.""" + +__author__ = ["fkiraly"] + +from skpro.base import BaseObject +from skpro.datatypes._common import _ret +from skpro.utils.deep_equals import deep_equals + + +class BaseDatatype(BaseObject): + """Base class for data types. + + This class is the base class for all data types in sktime. + """ + + _tags = { + "object_type": "datatype", + "scitype": None, + "name": None, # any string + "name_python": None, # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": None, + } + + def __init__(self): + super().__init__() + + def check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + If self has parameters set, the check will in addition + check whether metadata of obj is equal to self's parameters. + In this case, ``return_metadata`` will always include the + metadata fields required to check the parameters. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : instance of self only returned if return_metadata is True. + Metadata dictionary. 
+ """ + self_params = self.get_params() + + need_check = [k for k in self_params if self_params[k] is not None] + self_dict = {k: self_params[k] for k in need_check} + + return_metadata_orig = return_metadata + + # update return_metadata to retrieve any self_params + # return_metadata_bool has updated condition + if not len(need_check) == 0: + if isinstance(return_metadata, bool): + if not return_metadata: + return_metadata = need_check + return_metadata_bool = True + else: + return_metadata = set(return_metadata).union(need_check) + return_metadata = list(return_metadata) + return_metadata_bool = True + elif isinstance(return_metadata, bool): + return_metadata_bool = return_metadata + else: + return_metadata_bool = True + + # call inner _check + check_res = self._check( + obj=obj, return_metadata=return_metadata, var_name=var_name + ) + + if return_metadata_bool: + valid = check_res[0] + msg = check_res[1] + metadata = check_res[2] + else: + valid = check_res + msg = "" + + if not valid: + return _ret(False, msg, None, return_metadata_orig) + + # now we know the check is valid, but we need to compare fields + metadata_sub = {k: metadata[k] for k in self_dict} + eqs, msg = deep_equals(self_dict, metadata_sub, return_msg=True) + if not eqs: + msg = f"metadata of type unequal, {msg}" + return _ret(False, msg, None, return_metadata_orig) + + self_type = type(self)(**metadata) + return _ret(True, "", self_type, return_metadata_orig) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + raise NotImplementedError + + def __getitem__(self, key): + """Get attribute by key. + + Parameters + ---------- + key : str + Attribute name. + + Returns + ------- + value : any + Attribute value. + """ + return getattr(self, key) + + def get(self, key, default=None): + """Get attribute by key. + + Parameters + ---------- + key : str + Attribute name. + default : any, optional (default=None) + Default value if attribute does not exist. + + Returns + ------- + value : any + Attribute value. + """ + return getattr(self, key, default) + + +class BaseConverter(BaseObject): + """Base class for data type converters. + + This class is the base class for all data type converters in sktime. + """ + + _tags = { + "object_type": "converter", + "scitype": None, + "mtype_from": None, # equal to name field + "mtype_to": None, # equal to name field + "python_version": None, + "python_dependencies": None, + } + + def __init__(self): + super().__init__() + + def convert(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. + """ + return self._convert(obj, store) + + def _convert(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. 
+ """ + raise NotImplementedError From b35a1fa35e1a8b2a2232a1e5fe4937bf0b97821b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Jun 2024 16:29:31 +0100 Subject: [PATCH 03/40] base table --- skpro/datatypes/_table/_base.py | 55 +++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 skpro/datatypes/_table/_base.py diff --git a/skpro/datatypes/_table/_base.py b/skpro/datatypes/_table/_base.py new file mode 100644 index 000000000..3007be8de --- /dev/null +++ b/skpro/datatypes/_table/_base.py @@ -0,0 +1,55 @@ +# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +"""Base class for data types.""" + +__author__ = ["fkiraly"] + +from skpro.datatypes._base import BaseDatatype + + +class BaseTable(BaseDatatype): + """Base class for Table data types. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": None, # any string + "name_python": None, # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": None, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + self.is_univariate = is_univariate + self.is_empty = is_empty + self.has_nans = has_nans + self.n_instances = n_instances + self.n_features = n_features + self.feature_names = feature_names + + super().__init__() From 00a57b97195f6cdff83805b959a4be23bfa83f0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Jun 2024 16:30:39 +0100 Subject: [PATCH 04/40] docstr --- skpro/datatypes/_table/_base.py | 2 +- skpro/utils/retrieval.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/skpro/datatypes/_table/_base.py b/skpro/datatypes/_table/_base.py index 3007be8de..07dfb5c73 100644 --- a/skpro/datatypes/_table/_base.py +++ b/skpro/datatypes/_table/_base.py @@ -1,4 +1,4 @@ -# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) """Base class for data types.""" __author__ = ["fkiraly"] diff --git a/skpro/utils/retrieval.py b/skpro/utils/retrieval.py index b9c7902cb..54eac58b8 100644 --- a/skpro/utils/retrieval.py +++ b/skpro/utils/retrieval.py @@ -1,3 +1,4 @@ +"""Utility functions for retrieving objects from modules.""" import importlib import inspect import pkgutil From f0f80fe8b5063b64990c75471cf4fc9bfcd4d0f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Jun 2024 19:32:01 +0100 Subject: [PATCH 05/40] Revert "[MNT] increase `numpy` bound to `numpy < 2.1`, `numpy 2` compatibility (#393)" This reverts commit 0abc014f365a73083908841059a06aa0b4133ac0. 
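Before the dependency-pin commits that follow, a minimal sketch of the check contract introduced by the "base cls" and "base table" commits above. ToyListType is hypothetical and not part of skpro; the sketch also assumes the skpro.datatypes._base import fix from later in this series is applied, so that BaseDatatype is importable from skpro rather than sktime.

# Hypothetical example only: a toy datatype subclassing BaseDatatype.
from skpro.datatypes._base import BaseDatatype
from skpro.datatypes._common import _ret


class ToyListType(BaseDatatype):
    """Toy data type: a plain Python list, with n_instances as its only metadata."""

    _tags = {"scitype": "Toy", "name": "toy_list", "name_python": "toy_list"}

    def __init__(self, n_instances=None):
        self.n_instances = n_instances
        super().__init__()

    def _check(self, obj, return_metadata=False, var_name="obj"):
        # inner check: report validity plus any requested metadata
        if not isinstance(obj, list):
            return _ret(False, f"{var_name} is not a list", None, return_metadata)
        metadata = {"n_instances": len(obj)}
        return _ret(True, None, metadata, return_metadata)


# parameters set on the instance are compared against the inferred metadata
ToyListType(n_instances=3).check([1, 2, 3])  # True
ToyListType(n_instances=2).check([1, 2, 3])  # False, inferred n_instances is 3

# with return_metadata=True, check returns (valid, msg, populated datatype instance)
ToyListType().check([1, 2, 3], return_metadata=True)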
--- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 38493af1b..adbf59aea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ classifiers = [ ] requires-python = ">=3.8,<3.13" dependencies = [ - "numpy>=1.21.0,<2.1", + "numpy>=1.21.0,<1.27", "pandas>=1.1.0,<2.3.0", "packaging", "scikit-base>=0.6.1,<0.9.0", From cb3ae051f37b1fa7ad186cfec013b00e926f48f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Jun 2024 19:32:04 +0100 Subject: [PATCH 06/40] Reapply "[MNT] increase `numpy` bound to `numpy < 2.1`, `numpy 2` compatibility (#393)" This reverts commit f0f80fe8b5063b64990c75471cf4fc9bfcd4d0f7. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index adbf59aea..38493af1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ classifiers = [ ] requires-python = ">=3.8,<3.13" dependencies = [ - "numpy>=1.21.0,<1.27", + "numpy>=1.21.0,<2.1", "pandas>=1.1.0,<2.3.0", "packaging", "scikit-base>=0.6.1,<0.9.0", From 449b579e067678cc68593d28831c6a3dffc9c247 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Jun 2024 19:32:38 +0100 Subject: [PATCH 07/40] Revert "Reapply "[MNT] increase `numpy` bound to `numpy < 2.1`, `numpy 2` compatibility (#393)"" This reverts commit cb3ae051f37b1fa7ad186cfec013b00e926f48f4. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 38493af1b..adbf59aea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ classifiers = [ ] requires-python = ">=3.8,<3.13" dependencies = [ - "numpy>=1.21.0,<2.1", + "numpy>=1.21.0,<1.27", "pandas>=1.1.0,<2.3.0", "packaging", "scikit-base>=0.6.1,<0.9.0", From e16e87a1ee22b8aa1feb370ca84d47e13be2ed61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Jun 2024 19:32:41 +0100 Subject: [PATCH 08/40] Reapply "Reapply "[MNT] increase `numpy` bound to `numpy < 2.1`, `numpy 2` compatibility (#393)"" This reverts commit 449b579e067678cc68593d28831c6a3dffc9c247. 
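Returning to the first commit in this series, a short usage sketch of the crawling helpers in skpro/utils/retrieval.py. These are internal (underscore-prefixed) utilities; the target module skpro.datatypes is chosen purely for illustration, and the sketch assumes any soft dependencies imported by the crawled submodules are available in the environment.

# Illustrative only: list objects defined under a package, via retrieval.py.
from skpro.utils.retrieval import _all_classes, _all_functions

# (name, class) pairs defined anywhere under skpro.datatypes; submodules whose
# last component starts with "all", "test" or "contrib" are skipped, and objects
# that are merely re-imported into a module are ignored
classes = _all_classes("skpro.datatypes")

# same traversal, collecting functions instead of classes; results are cached
# per (module_name, condition) by the lru_cache on _all_cond, and the public
# wrappers hand back a copy so the cache is not mutated by callers
functions = _all_functions("skpro.datatypes")

print(len(classes), len(functions))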
--- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index adbf59aea..38493af1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ classifiers = [ ] requires-python = ">=3.8,<3.13" dependencies = [ - "numpy>=1.21.0,<1.27", + "numpy>=1.21.0,<2.1", "pandas>=1.1.0,<2.3.0", "packaging", "scikit-base>=0.6.1,<0.9.0", From aa68bb8b5f54f0ce59e4478f07376283a73143b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 18:54:36 +0100 Subject: [PATCH 09/40] continuing refactor --- skpro/datatypes/_base/__init__.py | 2 +- skpro/datatypes/_check.py | 69 +++- skpro/datatypes/_table/_check.py | 552 +++++++++++++++++++++++++++--- 3 files changed, 577 insertions(+), 46 deletions(-) diff --git a/skpro/datatypes/_base/__init__.py b/skpro/datatypes/_base/__init__.py index e04eb34d9..b6727e858 100644 --- a/skpro/datatypes/_base/__init__.py +++ b/skpro/datatypes/_base/__init__.py @@ -1,5 +1,5 @@ """Base module for datatypes.""" -from sktime.datatypes._base._base import BaseConverter, BaseDatatype +from skpro.datatypes._base._base import BaseConverter, BaseDatatype __all__ = ["BaseConverter", "BaseDatatype"] diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index 1aff1473a..d50eb6bd7 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -23,21 +23,80 @@ "mtype", ] +import importlib +import inspect +import pkgutil + import numpy as np +from skpro.datatypes._base import BaseDatatype from skpro.datatypes._common import _metadata_requested, _ret from skpro.datatypes._proba import check_dict_Proba from skpro.datatypes._registry import AMBIGUOUS_MTYPES, SCITYPE_LIST, mtype_to_scitype -from skpro.datatypes._table import check_dict_Table -# pool convert_dict-s -check_dict = dict() -check_dict.update(check_dict_Table) -check_dict.update(check_dict_Proba) + +check_dict = {} + + +def get_check_dict(): + """Retrieve check_dict, caches the first time it is requested. + + This is to avoid repeated, time consuming crawling in generate_check_dict, + which would otherwise be called every time check_dict is requested. + + Leaving the code on root level will also fail, due to circular imports. + """ + if len(check_dict) == 0: + check_dict.update(generate_check_dict()) + return check_dict + + +def generate_check_dict(): + """Generate check_dict using lookup.""" + from skbase.utils.dependencies import _check_estimator_deps + + from skpro import datatypes + + mod = datatypes + + classes = [] + for _, name, _ in pkgutil.walk_packages(mod.__path__, prefix=mod.__name__ + "."): + submodule = importlib.import_module(name) + for _, obj in inspect.getmembers(submodule): + if inspect.isclass(obj): + if not obj.__name__.startswith("Base"): + classes.append(obj) + classes = [x for x in classes if issubclass(x, BaseDatatype) and x != BaseDatatype] + + # this does not work, but should - bug in skbase? 
+ # ROOT = str(Path(__file__).parent) # sktime package root directory + # + # result = all_objects( + # object_types=BaseDatatype, + # package_name="sktime.datatypes", + # path=ROOT, + # return_names=False, + # ) + + # subset only to data types with soft dependencies present + result = [x for x in classes if _check_estimator_deps(x, severity="none")] + + check_dict = dict() + for k in result: + mtype = k.get_class_tag("name") + scitype = k.get_class_tag("scitype") + + check_dict[(mtype, scitype)] = k()._check + + # temporary while refactoring + check_dict.update(check_dict_Proba) + + return check_dict def _check_scitype_valid(scitype: str = None): """Check validity of scitype.""" + check_dict = get_check_dict() valid_scitypes = list({x[1] for x in check_dict.keys()}) if not isinstance(scitype, str): diff --git a/skpro/datatypes/_table/_check.py b/skpro/datatypes/_table/_check.py index 569f6b3c9..9ce1a4e57 100644 --- a/skpro/datatypes/_table/_check.py +++ b/skpro/datatypes/_table/_check.py @@ -1,14 +1,6 @@ -"""Machine type checkers for Table scitype. +"""Machine type cclasses for Table scitype. -Exports checkers for Table scitype: - -check_dict: dict indexed by pairs of str - 1st element = mtype - str - 2nd element = scitype - str -elements are checker/validation functions for mtype - -Function signature of all elements -check_dict[(mtype, scitype)] +Checks for each class are defined in the "check" method, of signature: Parameters ---------- @@ -36,20 +28,90 @@ __author__ = ["fkiraly"] -__all__ = ["check_dict"] - import numpy as np import pandas as pd from skpro.datatypes._common import _req, _ret -from skpro.utils.validation._dependencies import _check_soft_dependencies - -check_dict = dict() +from skpro.datatypes._table._base import BaseTable PRIMITIVE_TYPES = (float, int, str) +class TablePdDataFrame(BaseTable): + """Data type: pandas.DataFrame based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "pd_DataFrame_Table", # any string + "name_python": "table_pd_df", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "pandas", + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. 
+ metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_pddataframe_table(obj, return_metadata, var_name) + + def check_pddataframe_table(obj, return_metadata=False, var_name="obj"): metadata = dict() @@ -80,7 +142,79 @@ def check_pddataframe_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("pd_DataFrame_Table", "Table")] = check_pddataframe_table +class TablePdSeries(BaseTable): + """Data type: pandas.Series based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "pd_Series_Table", # any string + "name_python": "table_pd_series", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "pandas", + "capability:multivariate": True, + "capability:unequally_spaced": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_pdseries_table(obj, return_metadata, var_name) def check_pdseries_table(obj, return_metadata=False, var_name="obj"): @@ -119,7 +253,78 @@ def check_pdseries_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("pd_Series_Table", "Table")] = check_pdseries_table +class TableNp1D(BaseTable): + """Data type: 1D np.ndarray based specification of data frame table. + + Parameters are inferred by check. 
+ + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "numpy1D", # any string + "name_python": "table_numpy1d", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "numpy", + "capability:multivariate": False, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_numpy1d_table(obj, return_metadata, var_name) def check_numpy1d_table(obj, return_metadata=False, var_name="obj"): @@ -153,7 +358,78 @@ def check_numpy1d_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("numpy1D", "Table")] = check_numpy1d_table +class TableNp2D(BaseTable): + """Data type: 2D np.ndarray based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "numpy2D", # any string + "name_python": "table_numpy2d", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "numpy", + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. 
+ msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_numpy2d_table(obj, return_metadata, var_name) def check_numpy2d_table(obj, return_metadata=False, var_name="obj"): @@ -186,7 +462,78 @@ def check_numpy2d_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("numpy2D", "Table")] = check_numpy2d_table +class TableListOfDict(BaseTable): + """Data type: list of dict based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "list_of_dict", # any string + "name_python": "table_list_of_dict", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "numpy", + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_list_of_dict_table(obj, return_metadata, var_name) def check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): @@ -242,28 +589,153 @@ def check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("list_of_dict", "Table")] = check_list_of_dict_table - - -if _check_soft_dependencies(["polars", "pyarrow"], severity="none"): - from skpro.datatypes._adapter.polars import check_polars_frame - - def check_polars_table(obj, return_metadata=False, var_name="obj"): - return check_polars_frame( - obj=obj, - return_metadata=return_metadata, - var_name=var_name, - lazy=False, +class TablePolarsEager(BaseTable): + """Data type: eager polars DataFrame based specification of data frame table. + + Parameters are inferred by check. 
+ + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "polars_eager_table", # any string + "name_python": "table_polars_eager", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": ["polars", "pyarrow"], + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, ) - check_dict[("polars_eager_table", "Table")] = check_polars_table - - def check_polars_table_lazy(obj, return_metadata=False, var_name="obj"): - return check_polars_frame( - obj=obj, - return_metadata=return_metadata, - var_name=var_name, - lazy=True, + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + from skpro.datatypes._adapter.polars import check_polars_frame + + return check_polars_frame(obj, return_metadata, var_name, lazy=False) + + +class TablePolarsLazy(BaseTable): + """Data type: lazy polars DataFrame based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "polars_lazy_table", # any string + "name_python": "table_polars_lazy", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": ["polars", "pyarrow"], + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, ) - check_dict[("polars_lazy_table", "Table")] = check_polars_table_lazy + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. 
+ var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + from skpro.datatypes._adapter.polars import check_polars_frame + + return check_polars_frame(obj, return_metadata, var_name, lazy=True) From 64acb3ea0ac7d3e1c8a843e759e625de0c473095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 18:56:17 +0100 Subject: [PATCH 10/40] Revert "continuing refactor" This reverts commit aa68bb8b5f54f0ce59e4478f07376283a73143b8. --- skpro/datatypes/_base/__init__.py | 2 +- skpro/datatypes/_check.py | 69 +--- skpro/datatypes/_table/_check.py | 552 +++--------------------------- 3 files changed, 46 insertions(+), 577 deletions(-) diff --git a/skpro/datatypes/_base/__init__.py b/skpro/datatypes/_base/__init__.py index b6727e858..e04eb34d9 100644 --- a/skpro/datatypes/_base/__init__.py +++ b/skpro/datatypes/_base/__init__.py @@ -1,5 +1,5 @@ """Base module for datatypes.""" -from skpro.datatypes._base._base import BaseConverter, BaseDatatype +from sktime.datatypes._base._base import BaseConverter, BaseDatatype __all__ = ["BaseConverter", "BaseDatatype"] diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index d50eb6bd7..1aff1473a 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -23,80 +23,21 @@ "mtype", ] -import importlib -import inspect -import pkgutil - import numpy as np -from skpro.datatypes._base import BaseDatatype from skpro.datatypes._common import _metadata_requested, _ret from skpro.datatypes._proba import check_dict_Proba from skpro.datatypes._registry import AMBIGUOUS_MTYPES, SCITYPE_LIST, mtype_to_scitype +from skpro.datatypes._table import check_dict_Table - -check_dict = {} - - -def get_check_dict(): - """Retrieve check_dict, caches the first time it is requested. - - This is to avoid repeated, time consuming crawling in generate_check_dict, - which would otherwise be called every time check_dict is requested. - - Leaving the code on root level will also fail, due to circular imports. - """ - if len(check_dict) == 0: - check_dict.update(generate_check_dict()) - return check_dict - - -def generate_check_dict(): - """Generate check_dict using lookup.""" - from skbase.utils.dependencies import _check_estimator_deps - - from skpro import datatypes - - mod = datatypes - - classes = [] - for _, name, _ in pkgutil.walk_packages(mod.__path__, prefix=mod.__name__ + "."): - submodule = importlib.import_module(name) - for _, obj in inspect.getmembers(submodule): - if inspect.isclass(obj): - if not obj.__name__.startswith("Base"): - classes.append(obj) - classes = [x for x in classes if issubclass(x, BaseDatatype) and x != BaseDatatype] - - # this does not work, but should - bug in skbase? 
- # ROOT = str(Path(__file__).parent) # sktime package root directory - # - # result = all_objects( - # object_types=BaseDatatype, - # package_name="sktime.datatypes", - # path=ROOT, - # return_names=False, - # ) - - # subset only to data types with soft dependencies present - result = [x for x in classes if _check_estimator_deps(x, severity="none")] - - check_dict = dict() - for k in result: - mtype = k.get_class_tag("name") - scitype = k.get_class_tag("scitype") - - check_dict[(mtype, scitype)] = k()._check - - # temporary while refactoring - check_dict.update(check_dict_Proba) - - return check_dict +# pool convert_dict-s +check_dict = dict() +check_dict.update(check_dict_Table) +check_dict.update(check_dict_Proba) def _check_scitype_valid(scitype: str = None): """Check validity of scitype.""" - check_dict = get_check_dict() valid_scitypes = list({x[1] for x in check_dict.keys()}) if not isinstance(scitype, str): diff --git a/skpro/datatypes/_table/_check.py b/skpro/datatypes/_table/_check.py index 9ce1a4e57..569f6b3c9 100644 --- a/skpro/datatypes/_table/_check.py +++ b/skpro/datatypes/_table/_check.py @@ -1,6 +1,14 @@ -"""Machine type cclasses for Table scitype. +"""Machine type checkers for Table scitype. -Checks for each class are defined in the "check" method, of signature: +Exports checkers for Table scitype: + +check_dict: dict indexed by pairs of str + 1st element = mtype - str + 2nd element = scitype - str +elements are checker/validation functions for mtype + +Function signature of all elements +check_dict[(mtype, scitype)] Parameters ---------- @@ -28,88 +36,18 @@ __author__ = ["fkiraly"] +__all__ = ["check_dict"] + import numpy as np import pandas as pd from skpro.datatypes._common import _req, _ret -from skpro.datatypes._table._base import BaseTable - - -PRIMITIVE_TYPES = (float, int, str) +from skpro.utils.validation._dependencies import _check_soft_dependencies +check_dict = dict() -class TablePdDataFrame(BaseTable): - """Data type: pandas.DataFrame based specification of data frame table. - - Parameters are inferred by check. - - Parameters - ---------- - is_univariate: bool - True iff series has one variable - is_empty: bool - True iff series has no variables or no instances - has_nans: bool - True iff the series contains NaN values - n_instances: int - number of instances/rows in the table - n_features: int - number of variables in series - feature_names: list of int or object - names of variables in series - """ - - _tags = { - "scitype": "Table", - "name": "pd_DataFrame_Table", # any string - "name_python": "table_pd_df", # lower_snake_case - "name_aliases": [], - "python_version": None, - "python_dependencies": "pandas", - "capability:multivariate": True, - "capability:missing_values": True, - } - - def __init__( - self, - is_univariate=None, - is_empty=None, - has_nans=None, - n_instances=None, - n_features=None, - feature_names=None, - ): - super().__init__( - is_univariate=is_univariate, - n_instances=n_instances, - is_empty=is_empty, - has_nans=has_nans, - n_features=n_features, - feature_names=feature_names, - ) - def _check(self, obj, return_metadata=False, var_name="obj"): - """Check if obj is of this data type. - - Parameters - ---------- - obj : any - Object to check. - return_metadata : bool, optional (default=False) - Whether to return metadata. - var_name : str, optional (default="obj") - Name of the variable to check, for use in error messages. - - Returns - ------- - valid : bool - Whether obj is of this data type. 
- msg : str, only returned if return_metadata is True. - Error message if obj is not of this data type. - metadata : dict, only returned if return_metadata is True. - Metadata dictionary. - """ - return check_pddataframe_table(obj, return_metadata, var_name) +PRIMITIVE_TYPES = (float, int, str) def check_pddataframe_table(obj, return_metadata=False, var_name="obj"): @@ -142,79 +80,7 @@ def check_pddataframe_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -class TablePdSeries(BaseTable): - """Data type: pandas.Series based specification of data frame table. - - Parameters are inferred by check. - - Parameters - ---------- - is_univariate: bool - True iff series has one variable - is_empty: bool - True iff series has no variables or no instances - has_nans: bool - True iff the series contains NaN values - n_instances: int - number of instances/rows in the table - n_features: int - number of variables in series - feature_names: list of int or object - names of variables in series - """ - - _tags = { - "scitype": "Table", - "name": "pd_Series_Table", # any string - "name_python": "table_pd_series", # lower_snake_case - "name_aliases": [], - "python_version": None, - "python_dependencies": "pandas", - "capability:multivariate": True, - "capability:unequally_spaced": True, - "capability:missing_values": True, - } - - def __init__( - self, - is_univariate=None, - is_empty=None, - has_nans=None, - n_instances=None, - n_features=None, - feature_names=None, - ): - super().__init__( - is_univariate=is_univariate, - n_instances=n_instances, - is_empty=is_empty, - has_nans=has_nans, - n_features=n_features, - feature_names=feature_names, - ) - - def _check(self, obj, return_metadata=False, var_name="obj"): - """Check if obj is of this data type. - - Parameters - ---------- - obj : any - Object to check. - return_metadata : bool, optional (default=False) - Whether to return metadata. - var_name : str, optional (default="obj") - Name of the variable to check, for use in error messages. - - Returns - ------- - valid : bool - Whether obj is of this data type. - msg : str, only returned if return_metadata is True. - Error message if obj is not of this data type. - metadata : dict, only returned if return_metadata is True. - Metadata dictionary. - """ - return check_pdseries_table(obj, return_metadata, var_name) +check_dict[("pd_DataFrame_Table", "Table")] = check_pddataframe_table def check_pdseries_table(obj, return_metadata=False, var_name="obj"): @@ -253,78 +119,7 @@ def check_pdseries_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -class TableNp1D(BaseTable): - """Data type: 1D np.ndarray based specification of data frame table. - - Parameters are inferred by check. 
- - Parameters - ---------- - is_univariate: bool - True iff series has one variable - is_empty: bool - True iff series has no variables or no instances - has_nans: bool - True iff the series contains NaN values - n_instances: int - number of instances/rows in the table - n_features: int - number of variables in series - feature_names: list of int or object - names of variables in series - """ - - _tags = { - "scitype": "Table", - "name": "numpy1D", # any string - "name_python": "table_numpy1d", # lower_snake_case - "name_aliases": [], - "python_version": None, - "python_dependencies": "numpy", - "capability:multivariate": False, - "capability:missing_values": True, - } - - def __init__( - self, - is_univariate=None, - is_empty=None, - has_nans=None, - n_instances=None, - n_features=None, - feature_names=None, - ): - super().__init__( - is_univariate=is_univariate, - n_instances=n_instances, - is_empty=is_empty, - has_nans=has_nans, - n_features=n_features, - feature_names=feature_names, - ) - - def _check(self, obj, return_metadata=False, var_name="obj"): - """Check if obj is of this data type. - - Parameters - ---------- - obj : any - Object to check. - return_metadata : bool, optional (default=False) - Whether to return metadata. - var_name : str, optional (default="obj") - Name of the variable to check, for use in error messages. - - Returns - ------- - valid : bool - Whether obj is of this data type. - msg : str, only returned if return_metadata is True. - Error message if obj is not of this data type. - metadata : dict, only returned if return_metadata is True. - Metadata dictionary. - """ - return check_numpy1d_table(obj, return_metadata, var_name) +check_dict[("pd_Series_Table", "Table")] = check_pdseries_table def check_numpy1d_table(obj, return_metadata=False, var_name="obj"): @@ -358,78 +153,7 @@ def check_numpy1d_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -class TableNp2D(BaseTable): - """Data type: 2D np.ndarray based specification of data frame table. - - Parameters are inferred by check. - - Parameters - ---------- - is_univariate: bool - True iff series has one variable - is_empty: bool - True iff series has no variables or no instances - has_nans: bool - True iff the series contains NaN values - n_instances: int - number of instances/rows in the table - n_features: int - number of variables in series - feature_names: list of int or object - names of variables in series - """ - - _tags = { - "scitype": "Table", - "name": "numpy2D", # any string - "name_python": "table_numpy2d", # lower_snake_case - "name_aliases": [], - "python_version": None, - "python_dependencies": "numpy", - "capability:multivariate": True, - "capability:missing_values": True, - } - - def __init__( - self, - is_univariate=None, - is_empty=None, - has_nans=None, - n_instances=None, - n_features=None, - feature_names=None, - ): - super().__init__( - is_univariate=is_univariate, - n_instances=n_instances, - is_empty=is_empty, - has_nans=has_nans, - n_features=n_features, - feature_names=feature_names, - ) - - def _check(self, obj, return_metadata=False, var_name="obj"): - """Check if obj is of this data type. - - Parameters - ---------- - obj : any - Object to check. - return_metadata : bool, optional (default=False) - Whether to return metadata. - var_name : str, optional (default="obj") - Name of the variable to check, for use in error messages. - - Returns - ------- - valid : bool - Whether obj is of this data type. 
- msg : str, only returned if return_metadata is True. - Error message if obj is not of this data type. - metadata : dict, only returned if return_metadata is True. - Metadata dictionary. - """ - return check_numpy2d_table(obj, return_metadata, var_name) +check_dict[("numpy1D", "Table")] = check_numpy1d_table def check_numpy2d_table(obj, return_metadata=False, var_name="obj"): @@ -462,78 +186,7 @@ def check_numpy2d_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -class TableListOfDict(BaseTable): - """Data type: list of dict based specification of data frame table. - - Parameters are inferred by check. - - Parameters - ---------- - is_univariate: bool - True iff series has one variable - is_empty: bool - True iff series has no variables or no instances - has_nans: bool - True iff the series contains NaN values - n_instances: int - number of instances/rows in the table - n_features: int - number of variables in series - feature_names: list of int or object - names of variables in series - """ - - _tags = { - "scitype": "Table", - "name": "list_of_dict", # any string - "name_python": "table_list_of_dict", # lower_snake_case - "name_aliases": [], - "python_version": None, - "python_dependencies": "numpy", - "capability:multivariate": True, - "capability:missing_values": True, - } - - def __init__( - self, - is_univariate=None, - is_empty=None, - has_nans=None, - n_instances=None, - n_features=None, - feature_names=None, - ): - super().__init__( - is_univariate=is_univariate, - n_instances=n_instances, - is_empty=is_empty, - has_nans=has_nans, - n_features=n_features, - feature_names=feature_names, - ) - - def _check(self, obj, return_metadata=False, var_name="obj"): - """Check if obj is of this data type. - - Parameters - ---------- - obj : any - Object to check. - return_metadata : bool, optional (default=False) - Whether to return metadata. - var_name : str, optional (default="obj") - Name of the variable to check, for use in error messages. - - Returns - ------- - valid : bool - Whether obj is of this data type. - msg : str, only returned if return_metadata is True. - Error message if obj is not of this data type. - metadata : dict, only returned if return_metadata is True. - Metadata dictionary. - """ - return check_list_of_dict_table(obj, return_metadata, var_name) +check_dict[("numpy2D", "Table")] = check_numpy2d_table def check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): @@ -589,153 +242,28 @@ def check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -class TablePolarsEager(BaseTable): - """Data type: eager polars DataFrame based specification of data frame table. - - Parameters are inferred by check. 
- - Parameters - ---------- - is_univariate: bool - True iff series has one variable - is_empty: bool - True iff series has no variables or no instances - has_nans: bool - True iff the series contains NaN values - n_instances: int - number of instances/rows in the table - n_features: int - number of variables in series - feature_names: list of int or object - names of variables in series - """ - - _tags = { - "scitype": "Table", - "name": "polars_eager_table", # any string - "name_python": "table_polars_eager", # lower_snake_case - "name_aliases": [], - "python_version": None, - "python_dependencies": ["polars", "pyarrow"], - "capability:multivariate": True, - "capability:missing_values": True, - } - - def __init__( - self, - is_univariate=None, - is_empty=None, - has_nans=None, - n_instances=None, - n_features=None, - feature_names=None, - ): - super().__init__( - is_univariate=is_univariate, - n_instances=n_instances, - is_empty=is_empty, - has_nans=has_nans, - n_features=n_features, - feature_names=feature_names, +check_dict[("list_of_dict", "Table")] = check_list_of_dict_table + + +if _check_soft_dependencies(["polars", "pyarrow"], severity="none"): + from skpro.datatypes._adapter.polars import check_polars_frame + + def check_polars_table(obj, return_metadata=False, var_name="obj"): + return check_polars_frame( + obj=obj, + return_metadata=return_metadata, + var_name=var_name, + lazy=False, ) - def _check(self, obj, return_metadata=False, var_name="obj"): - """Check if obj is of this data type. - - Parameters - ---------- - obj : any - Object to check. - return_metadata : bool, optional (default=False) - Whether to return metadata. - var_name : str, optional (default="obj") - Name of the variable to check, for use in error messages. - - Returns - ------- - valid : bool - Whether obj is of this data type. - msg : str, only returned if return_metadata is True. - Error message if obj is not of this data type. - metadata : dict, only returned if return_metadata is True. - Metadata dictionary. - """ - from skpro.datatypes._adapter.polars import check_polars_frame - - return check_polars_frame(obj, return_metadata, var_name, lazy=False) - - -class TablePolarsLazy(BaseTable): - """Data type: lazy polars DataFrame based specification of data frame table. - - Parameters are inferred by check. 
- - Parameters - ---------- - is_univariate: bool - True iff series has one variable - is_empty: bool - True iff series has no variables or no instances - has_nans: bool - True iff the series contains NaN values - n_instances: int - number of instances/rows in the table - n_features: int - number of variables in series - feature_names: list of int or object - names of variables in series - """ - - _tags = { - "scitype": "Table", - "name": "polars_lazy_table", # any string - "name_python": "table_polars_lazy", # lower_snake_case - "name_aliases": [], - "python_version": None, - "python_dependencies": ["polars", "pyarrow"], - "capability:multivariate": True, - "capability:missing_values": True, - } - - def __init__( - self, - is_univariate=None, - is_empty=None, - has_nans=None, - n_instances=None, - n_features=None, - feature_names=None, - ): - super().__init__( - is_univariate=is_univariate, - n_instances=n_instances, - is_empty=is_empty, - has_nans=has_nans, - n_features=n_features, - feature_names=feature_names, + check_dict[("polars_eager_table", "Table")] = check_polars_table + + def check_polars_table_lazy(obj, return_metadata=False, var_name="obj"): + return check_polars_frame( + obj=obj, + return_metadata=return_metadata, + var_name=var_name, + lazy=True, ) - def _check(self, obj, return_metadata=False, var_name="obj"): - """Check if obj is of this data type. - - Parameters - ---------- - obj : any - Object to check. - return_metadata : bool, optional (default=False) - Whether to return metadata. - var_name : str, optional (default="obj") - Name of the variable to check, for use in error messages. - - Returns - ------- - valid : bool - Whether obj is of this data type. - msg : str, only returned if return_metadata is True. - Error message if obj is not of this data type. - metadata : dict, only returned if return_metadata is True. - Metadata dictionary. - """ - from skpro.datatypes._adapter.polars import check_polars_frame - - return check_polars_frame(obj, return_metadata, var_name, lazy=True) + check_dict[("polars_lazy_table", "Table")] = check_polars_table_lazy From 5a05ca229ae2374b8b0fcc906ed948cd328a9c23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 18:56:20 +0100 Subject: [PATCH 11/40] Reapply "continuing refactor" This reverts commit 64acb3ea0ac7d3e1c8a843e759e625de0c473095. 
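Assuming the reapplied refactor below, a rough sketch of how the cached registry is meant to be queried. get_check_dict is an internal helper rather than public API, and the exact set of keys depends on which soft dependencies (for example polars and pyarrow) are installed.

import pandas as pd

from skpro.datatypes._check import get_check_dict

# first call crawls skpro.datatypes for BaseDatatype subclasses and caches the
# result; subsequent calls return the already-populated dict
check_dict = get_check_dict()

# keys are (mtype, scitype) pairs taken from the "name" and "scitype" tags
print(sorted(k[0] for k in check_dict if k[1] == "Table"))

# Table entries map to the bound _check methods of the new datatype classes
checker = check_dict[("pd_DataFrame_Table", "Table")]
valid, msg, metadata = checker(pd.DataFrame({"a": [1, 2, 3]}), return_metadata=True)
# valid is True and metadata is a dict of inferred properties such as is_univariate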
--- skpro/datatypes/_base/__init__.py | 2 +- skpro/datatypes/_check.py | 69 +++- skpro/datatypes/_table/_check.py | 552 +++++++++++++++++++++++++++--- 3 files changed, 577 insertions(+), 46 deletions(-) diff --git a/skpro/datatypes/_base/__init__.py b/skpro/datatypes/_base/__init__.py index e04eb34d9..b6727e858 100644 --- a/skpro/datatypes/_base/__init__.py +++ b/skpro/datatypes/_base/__init__.py @@ -1,5 +1,5 @@ """Base module for datatypes.""" -from sktime.datatypes._base._base import BaseConverter, BaseDatatype +from skpro.datatypes._base._base import BaseConverter, BaseDatatype __all__ = ["BaseConverter", "BaseDatatype"] diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index 1aff1473a..d50eb6bd7 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -23,21 +23,80 @@ "mtype", ] +import importlib +import inspect +import pkgutil + import numpy as np +from skpro.datatypes._base import BaseDatatype from skpro.datatypes._common import _metadata_requested, _ret from skpro.datatypes._proba import check_dict_Proba from skpro.datatypes._registry import AMBIGUOUS_MTYPES, SCITYPE_LIST, mtype_to_scitype -from skpro.datatypes._table import check_dict_Table -# pool convert_dict-s -check_dict = dict() -check_dict.update(check_dict_Table) -check_dict.update(check_dict_Proba) + +check_dict = {} + + +def get_check_dict(): + """Retrieve check_dict, caches the first time it is requested. + + This is to avoid repeated, time consuming crawling in generate_check_dict, + which would otherwise be called every time check_dict is requested. + + Leaving the code on root level will also fail, due to circular imports. + """ + if len(check_dict) == 0: + check_dict.update(generate_check_dict()) + return check_dict + + +def generate_check_dict(): + """Generate check_dict using lookup.""" + from skbase.utils.dependencies import _check_estimator_deps + + from skpro import datatypes + + mod = datatypes + + classes = [] + for _, name, _ in pkgutil.walk_packages(mod.__path__, prefix=mod.__name__ + "."): + submodule = importlib.import_module(name) + for _, obj in inspect.getmembers(submodule): + if inspect.isclass(obj): + if not obj.__name__.startswith("Base"): + classes.append(obj) + classes = [x for x in classes if issubclass(x, BaseDatatype) and x != BaseDatatype] + + # this does not work, but should - bug in skbase? + # ROOT = str(Path(__file__).parent) # sktime package root directory + # + # result = all_objects( + # object_types=BaseDatatype, + # package_name="sktime.datatypes", + # path=ROOT, + # return_names=False, + # ) + + # subset only to data types with soft dependencies present + result = [x for x in classes if _check_estimator_deps(x, severity="none")] + + check_dict = dict() + for k in result: + mtype = k.get_class_tag("name") + scitype = k.get_class_tag("scitype") + + check_dict[(mtype, scitype)] = k()._check + + # temporary while refactoring + check_dict.update(check_dict_Proba) + + return check_dict def _check_scitype_valid(scitype: str = None): """Check validity of scitype.""" + check_dict = get_check_dict() valid_scitypes = list({x[1] for x in check_dict.keys()}) if not isinstance(scitype, str): diff --git a/skpro/datatypes/_table/_check.py b/skpro/datatypes/_table/_check.py index 569f6b3c9..9ce1a4e57 100644 --- a/skpro/datatypes/_table/_check.py +++ b/skpro/datatypes/_table/_check.py @@ -1,14 +1,6 @@ -"""Machine type checkers for Table scitype. +"""Machine type cclasses for Table scitype. 
-Exports checkers for Table scitype: - -check_dict: dict indexed by pairs of str - 1st element = mtype - str - 2nd element = scitype - str -elements are checker/validation functions for mtype - -Function signature of all elements -check_dict[(mtype, scitype)] +Checks for each class are defined in the "check" method, of signature: Parameters ---------- @@ -36,20 +28,90 @@ __author__ = ["fkiraly"] -__all__ = ["check_dict"] - import numpy as np import pandas as pd from skpro.datatypes._common import _req, _ret -from skpro.utils.validation._dependencies import _check_soft_dependencies - -check_dict = dict() +from skpro.datatypes._table._base import BaseTable PRIMITIVE_TYPES = (float, int, str) +class TablePdDataFrame(BaseTable): + """Data type: pandas.DataFrame based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "pd_DataFrame_Table", # any string + "name_python": "table_pd_df", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "pandas", + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_pddataframe_table(obj, return_metadata, var_name) + + def check_pddataframe_table(obj, return_metadata=False, var_name="obj"): metadata = dict() @@ -80,7 +142,79 @@ def check_pddataframe_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("pd_DataFrame_Table", "Table")] = check_pddataframe_table +class TablePdSeries(BaseTable): + """Data type: pandas.Series based specification of data frame table. + + Parameters are inferred by check. 
+ + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "pd_Series_Table", # any string + "name_python": "table_pd_series", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "pandas", + "capability:multivariate": True, + "capability:unequally_spaced": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_pdseries_table(obj, return_metadata, var_name) def check_pdseries_table(obj, return_metadata=False, var_name="obj"): @@ -119,7 +253,78 @@ def check_pdseries_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("pd_Series_Table", "Table")] = check_pdseries_table +class TableNp1D(BaseTable): + """Data type: 1D np.ndarray based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "numpy1D", # any string + "name_python": "table_numpy1d", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "numpy", + "capability:multivariate": False, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. 
+ + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_numpy1d_table(obj, return_metadata, var_name) def check_numpy1d_table(obj, return_metadata=False, var_name="obj"): @@ -153,7 +358,78 @@ def check_numpy1d_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("numpy1D", "Table")] = check_numpy1d_table +class TableNp2D(BaseTable): + """Data type: 2D np.ndarray based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "numpy2D", # any string + "name_python": "table_numpy2d", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "numpy", + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_numpy2d_table(obj, return_metadata, var_name) def check_numpy2d_table(obj, return_metadata=False, var_name="obj"): @@ -186,7 +462,78 @@ def check_numpy2d_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("numpy2D", "Table")] = check_numpy2d_table +class TableListOfDict(BaseTable): + """Data type: list of dict based specification of data frame table. + + Parameters are inferred by check. 
+ + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "list_of_dict", # any string + "name_python": "table_list_of_dict", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "numpy", + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return check_list_of_dict_table(obj, return_metadata, var_name) def check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): @@ -242,28 +589,153 @@ def check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("list_of_dict", "Table")] = check_list_of_dict_table - - -if _check_soft_dependencies(["polars", "pyarrow"], severity="none"): - from skpro.datatypes._adapter.polars import check_polars_frame - - def check_polars_table(obj, return_metadata=False, var_name="obj"): - return check_polars_frame( - obj=obj, - return_metadata=return_metadata, - var_name=var_name, - lazy=False, +class TablePolarsEager(BaseTable): + """Data type: eager polars DataFrame based specification of data frame table. + + Parameters are inferred by check. 
+ + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "polars_eager_table", # any string + "name_python": "table_polars_eager", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": ["polars", "pyarrow"], + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, ) - check_dict[("polars_eager_table", "Table")] = check_polars_table - - def check_polars_table_lazy(obj, return_metadata=False, var_name="obj"): - return check_polars_frame( - obj=obj, - return_metadata=return_metadata, - var_name=var_name, - lazy=True, + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + from skpro.datatypes._adapter.polars import check_polars_frame + + return check_polars_frame(obj, return_metadata, var_name, lazy=False) + + +class TablePolarsLazy(BaseTable): + """Data type: lazy polars DataFrame based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff series has one variable + is_empty: bool + True iff series has no variables or no instances + has_nans: bool + True iff the series contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in series + feature_names: list of int or object + names of variables in series + """ + + _tags = { + "scitype": "Table", + "name": "polars_lazy_table", # any string + "name_python": "table_polars_lazy", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": ["polars", "pyarrow"], + "capability:multivariate": True, + "capability:missing_values": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, ) - check_dict[("polars_lazy_table", "Table")] = check_polars_table_lazy + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. 
+ var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + from skpro.datatypes._adapter.polars import check_polars_frame + + return check_polars_frame(obj, return_metadata, var_name, lazy=True) From 3b6cdb9c931021b53ce261a43abf9f0f2267f7ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 19:02:12 +0100 Subject: [PATCH 12/40] linting --- skpro/datatypes/_check.py | 1 - skpro/datatypes/_table/_check.py | 21 ++++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index d50eb6bd7..6a3967b95 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -34,7 +34,6 @@ from skpro.datatypes._proba import check_dict_Proba from skpro.datatypes._registry import AMBIGUOUS_MTYPES, SCITYPE_LIST, mtype_to_scitype - check_dict = {} diff --git a/skpro/datatypes/_table/_check.py b/skpro/datatypes/_table/_check.py index 9ce1a4e57..d99b0bdd4 100644 --- a/skpro/datatypes/_table/_check.py +++ b/skpro/datatypes/_table/_check.py @@ -34,7 +34,6 @@ from skpro.datatypes._common import _req, _ret from skpro.datatypes._table._base import BaseTable - PRIMITIVE_TYPES = (float, int, str) @@ -109,10 +108,10 @@ def _check(self, obj, return_metadata=False, var_name="obj"): metadata : dict, only returned if return_metadata is True. Metadata dictionary. """ - return check_pddataframe_table(obj, return_metadata, var_name) + return _check_pddataframe_table(obj, return_metadata, var_name) -def check_pddataframe_table(obj, return_metadata=False, var_name="obj"): +def _check_pddataframe_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, pd.DataFrame): @@ -214,10 +213,10 @@ def _check(self, obj, return_metadata=False, var_name="obj"): metadata : dict, only returned if return_metadata is True. Metadata dictionary. """ - return check_pdseries_table(obj, return_metadata, var_name) + return _check_pdseries_table(obj, return_metadata, var_name) -def check_pdseries_table(obj, return_metadata=False, var_name="obj"): +def _check_pdseries_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, pd.Series): @@ -324,10 +323,10 @@ def _check(self, obj, return_metadata=False, var_name="obj"): metadata : dict, only returned if return_metadata is True. Metadata dictionary. """ - return check_numpy1d_table(obj, return_metadata, var_name) + return _check_numpy1d_table(obj, return_metadata, var_name) -def check_numpy1d_table(obj, return_metadata=False, var_name="obj"): +def _check_numpy1d_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, np.ndarray): @@ -429,10 +428,10 @@ def _check(self, obj, return_metadata=False, var_name="obj"): metadata : dict, only returned if return_metadata is True. Metadata dictionary. 
""" - return check_numpy2d_table(obj, return_metadata, var_name) + return _check_numpy2d_table(obj, return_metadata, var_name) -def check_numpy2d_table(obj, return_metadata=False, var_name="obj"): +def _check_numpy2d_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, np.ndarray): @@ -533,10 +532,10 @@ def _check(self, obj, return_metadata=False, var_name="obj"): metadata : dict, only returned if return_metadata is True. Metadata dictionary. """ - return check_list_of_dict_table(obj, return_metadata, var_name) + return _check_list_of_dict_table(obj, return_metadata, var_name) -def check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): +def _check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, list): From 65a89fbbd16b848c684c966b9c97148a0ac57dd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 19:09:03 +0100 Subject: [PATCH 13/40] remove imports --- skpro/datatypes/_table/__init__.py | 2 -- skpro/tests/test_polars.py | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/skpro/datatypes/_table/__init__.py b/skpro/datatypes/_table/__init__.py index ef620b0d9..0481dcaee 100644 --- a/skpro/datatypes/_table/__init__.py +++ b/skpro/datatypes/_table/__init__.py @@ -1,6 +1,5 @@ """Module exports: Series type checkers, converters and mtype inference.""" -from skpro.datatypes._table._check import check_dict as check_dict_Table from skpro.datatypes._table._convert import convert_dict as convert_dict_Table from skpro.datatypes._table._examples import example_dict as example_dict_Table from skpro.datatypes._table._examples import ( @@ -12,7 +11,6 @@ from skpro.datatypes._table._registry import MTYPE_LIST_TABLE, MTYPE_REGISTER_TABLE __all__ = [ - "check_dict_Table", "convert_dict_Table", "MTYPE_LIST_TABLE", "MTYPE_REGISTER_TABLE", diff --git a/skpro/tests/test_polars.py b/skpro/tests/test_polars.py index 796454c94..04a75b5e0 100644 --- a/skpro/tests/test_polars.py +++ b/skpro/tests/test_polars.py @@ -10,7 +10,7 @@ if _check_soft_dependencies(["polars", "pyarrow"], severity="none"): import polars as pl - from skpro.datatypes._table._check import check_polars_table + from skpro.datatypes._adapter.polars import check_polars_frame from skpro.datatypes._table._convert import convert_pandas_to_polars_eager TEST_ALPHAS = [0.05, 0.1, 0.25] @@ -67,9 +67,9 @@ def test_polars_eager_conversion_methods( X_train, X_test, y_train = polars_load_diabetes_pandas X_train_pl, X_test_pl, y_train_pl = polars_load_diabetes_polars - assert check_polars_table(X_train_pl) - assert check_polars_table(X_test_pl) - assert check_polars_table(y_train_pl) + assert check_polars_frame(X_train_pl) + assert check_polars_frame(X_test_pl) + assert check_polars_frame(y_train_pl) assert (X_train.values == X_train_pl.to_numpy()).all() assert (X_test.values == X_test_pl.to_numpy()).all() assert (y_train.values == y_train_pl.to_numpy()).all() From b1f1abda3497d318d6e1788cd6e9784863df8422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 19:26:30 +0100 Subject: [PATCH 14/40] docstrings --- skpro/datatypes/_table/_base.py | 10 ++--- skpro/datatypes/_table/_check.py | 70 ++++++++++++++++---------------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/skpro/datatypes/_table/_base.py b/skpro/datatypes/_table/_base.py index 07dfb5c73..3681e1fd0 100644 --- a/skpro/datatypes/_table/_base.py +++ b/skpro/datatypes/_table/_base.py @@ 
-14,17 +14,17 @@ class BaseTable(BaseDatatype): Parameters ---------- is_univariate: bool - True iff series has one variable + True iff table has one variable is_empty: bool - True iff series has no variables or no instances + True iff table has no variables or no instances has_nans: bool - True iff the series contains NaN values + True iff the table contains NaN values n_instances: int number of instances/rows in the table n_features: int - number of variables in series + number of variables in table feature_names: list of int or object - names of variables in series + names of variables in table """ _tags = { diff --git a/skpro/datatypes/_table/_check.py b/skpro/datatypes/_table/_check.py index d99b0bdd4..09b8c7043 100644 --- a/skpro/datatypes/_table/_check.py +++ b/skpro/datatypes/_table/_check.py @@ -45,17 +45,17 @@ class TablePdDataFrame(BaseTable): Parameters ---------- is_univariate: bool - True iff series has one variable + True iff table has one variable is_empty: bool - True iff series has no variables or no instances + True iff table has no variables or no instances has_nans: bool - True iff the series contains NaN values + True iff the table contains NaN values n_instances: int number of instances/rows in the table n_features: int - number of variables in series + number of variables in table feature_names: list of int or object - names of variables in series + names of variables in table """ _tags = { @@ -149,17 +149,17 @@ class TablePdSeries(BaseTable): Parameters ---------- is_univariate: bool - True iff series has one variable + True iff table has one variable is_empty: bool - True iff series has no variables or no instances + True iff table has no variables or no instances has_nans: bool - True iff the series contains NaN values + True iff the table contains NaN values n_instances: int number of instances/rows in the table n_features: int - number of variables in series + number of variables in table feature_names: list of int or object - names of variables in series + names of variables in table """ _tags = { @@ -260,17 +260,17 @@ class TableNp1D(BaseTable): Parameters ---------- is_univariate: bool - True iff series has one variable + True iff table has one variable is_empty: bool - True iff series has no variables or no instances + True iff table has no variables or no instances has_nans: bool - True iff the series contains NaN values + True iff the table contains NaN values n_instances: int number of instances/rows in the table n_features: int - number of variables in series + number of variables in table feature_names: list of int or object - names of variables in series + names of variables in table """ _tags = { @@ -365,17 +365,17 @@ class TableNp2D(BaseTable): Parameters ---------- is_univariate: bool - True iff series has one variable + True iff table has one variable is_empty: bool - True iff series has no variables or no instances + True iff table has no variables or no instances has_nans: bool - True iff the series contains NaN values + True iff the table contains NaN values n_instances: int number of instances/rows in the table n_features: int - number of variables in series + number of variables in table feature_names: list of int or object - names of variables in series + names of variables in table """ _tags = { @@ -469,17 +469,17 @@ class TableListOfDict(BaseTable): Parameters ---------- is_univariate: bool - True iff series has one variable + True iff table has one variable is_empty: bool - True iff series has no variables or no instances + True iff table 
has no variables or no instances has_nans: bool - True iff the series contains NaN values + True iff the table contains NaN values n_instances: int number of instances/rows in the table n_features: int - number of variables in series + number of variables in table feature_names: list of int or object - names of variables in series + names of variables in table """ _tags = { @@ -596,17 +596,17 @@ class TablePolarsEager(BaseTable): Parameters ---------- is_univariate: bool - True iff series has one variable + True iff table has one variable is_empty: bool - True iff series has no variables or no instances + True iff table has no variables or no instances has_nans: bool - True iff the series contains NaN values + True iff the table contains NaN values n_instances: int number of instances/rows in the table n_features: int - number of variables in series + number of variables in table feature_names: list of int or object - names of variables in series + names of variables in table """ _tags = { @@ -672,17 +672,17 @@ class TablePolarsLazy(BaseTable): Parameters ---------- is_univariate: bool - True iff series has one variable + True iff table has one variable is_empty: bool - True iff series has no variables or no instances + True iff table has no variables or no instances has_nans: bool - True iff the series contains NaN values + True iff the table contains NaN values n_instances: int number of instances/rows in the table n_features: int - number of variables in series + number of variables in table feature_names: list of int or object - names of variables in series + names of variables in table """ _tags = { From d6903a041c4ee4f9cf0998298e552f7ff2d56868 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 22:30:37 +0100 Subject: [PATCH 15/40] fix tags --- skpro/datatypes/_table/_check.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/skpro/datatypes/_table/_check.py b/skpro/datatypes/_table/_check.py index 09b8c7043..f02cac429 100644 --- a/skpro/datatypes/_table/_check.py +++ b/skpro/datatypes/_table/_check.py @@ -67,6 +67,7 @@ class TablePdDataFrame(BaseTable): "python_dependencies": "pandas", "capability:multivariate": True, "capability:missing_values": True, + "capability:index": True, } def __init__( @@ -169,9 +170,9 @@ class TablePdSeries(BaseTable): "name_aliases": [], "python_version": None, "python_dependencies": "pandas", - "capability:multivariate": True, - "capability:unequally_spaced": True, + "capability:multivariate": False, "capability:missing_values": True, + "capability:index": True, } def __init__( @@ -282,6 +283,7 @@ class TableNp1D(BaseTable): "python_dependencies": "numpy", "capability:multivariate": False, "capability:missing_values": True, + "capability:index": False, } def __init__( @@ -387,6 +389,7 @@ class TableNp2D(BaseTable): "python_dependencies": "numpy", "capability:multivariate": True, "capability:missing_values": True, + "capability:index": False, } def __init__( @@ -491,6 +494,7 @@ class TableListOfDict(BaseTable): "python_dependencies": "numpy", "capability:multivariate": True, "capability:missing_values": True, + "capability:index": False, } def __init__( @@ -618,6 +622,7 @@ class TablePolarsEager(BaseTable): "python_dependencies": ["polars", "pyarrow"], "capability:multivariate": True, "capability:missing_values": True, + "capability:index": False, } def __init__( @@ -694,6 +699,7 @@ class TablePolarsLazy(BaseTable): "python_dependencies": ["polars", "pyarrow"], "capability:multivariate": True, 
"capability:missing_values": True, + "capability:index": False, } def __init__( From dbf08ef63a3533e5f59b478749d437ad9545bad5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 23:24:40 +0100 Subject: [PATCH 16/40] simplify retrieval --- skpro/datatypes/_check.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index 6a3967b95..8dd6312b4 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -55,28 +55,13 @@ def generate_check_dict(): from skbase.utils.dependencies import _check_estimator_deps from skpro import datatypes + from skpro.utils.retrieval import _all_classes mod = datatypes - classes = [] - for _, name, _ in pkgutil.walk_packages(mod.__path__, prefix=mod.__name__ + "."): - submodule = importlib.import_module(name) - for _, obj in inspect.getmembers(submodule): - if inspect.isclass(obj): - if not obj.__name__.startswith("Base"): - classes.append(obj) + classes = _all_classes(mod) classes = [x for x in classes if issubclass(x, BaseDatatype) and x != BaseDatatype] - # this does not work, but should - bug in skbase? - # ROOT = str(Path(__file__).parent) # sktime package root directory - # - # result = all_objects( - # object_types=BaseDatatype, - # package_name="sktime.datatypes", - # path=ROOT, - # return_names=False, - # ) - # subset only to data types with soft dependencies present result = [x for x in classes if _check_estimator_deps(x, severity="none")] From 9d41a7f2b05c58de7f6be568f594ca1f34dc9844 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 23:25:27 +0100 Subject: [PATCH 17/40] Update _check.py --- skpro/datatypes/_check.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index 8dd6312b4..d0b43ef91 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -23,10 +23,6 @@ "mtype", ] -import importlib -import inspect -import pkgutil - import numpy as np from skpro.datatypes._base import BaseDatatype From adcad09bfb566263e9176f5dbe4f0426e17ac59f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 23:30:26 +0100 Subject: [PATCH 18/40] Update _check.py --- skpro/datatypes/_check.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index d0b43ef91..e1c8dae8e 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -50,13 +50,12 @@ def generate_check_dict(): """Generate check_dict using lookup.""" from skbase.utils.dependencies import _check_estimator_deps - from skpro import datatypes from skpro.utils.retrieval import _all_classes - mod = datatypes - - classes = _all_classes(mod) - classes = [x for x in classes if issubclass(x, BaseDatatype) and x != BaseDatatype] + classes = _all_classes("skpro.datatypes") + classes = [x[1] for x in classes] + classes = [x for x in classes if issubclass(x, BaseDatatype)] + classes = [x for x in classes if not x.__name__.startswith("Base")] # subset only to data types with soft dependencies present result = [x for x in classes if _check_estimator_deps(x, severity="none")] From 649f1ae0ef855725d0cf53869aa5dadc25bef88d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 14 Jul 2024 23:35:24 +0100 Subject: [PATCH 19/40] convert --- skpro/datatypes/_check.py | 2 +- skpro/datatypes/_convert.py | 47 ++++++++++++++++++++++++++++++++++--- 2 files 
changed, 45 insertions(+), 4 deletions(-) diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index e1c8dae8e..15ea30295 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -43,7 +43,7 @@ def get_check_dict(): """ if len(check_dict) == 0: check_dict.update(generate_check_dict()) - return check_dict + return check_dict.copy() def generate_check_dict(): diff --git a/skpro/datatypes/_convert.py b/skpro/datatypes/_convert.py index f823b47c4..50721dd14 100644 --- a/skpro/datatypes/_convert.py +++ b/skpro/datatypes/_convert.py @@ -69,15 +69,56 @@ import numpy as np import pandas as pd +from skpro.datatypes._base import BaseConverter from skpro.datatypes._check import mtype as infer_mtype from skpro.datatypes._proba import convert_dict_Proba from skpro.datatypes._registry import mtype_to_scitype from skpro.datatypes._table import convert_dict_Table # pool convert_dict-s and infer_mtype_dict-s -convert_dict = dict() -convert_dict.update(convert_dict_Table) -convert_dict.update(convert_dict_Proba) +convert_dict = {} + + +def get_convert_dict(): + """Retrieve convert_dict, caches the first time it is requested. + + This is to avoid repeated, time consuming crawling in generate_check_dict, + which would otherwise be called every time check_dict is requested. + + Leaving the code on root level will also fail, due to circular imports. + """ + if len(convert_dict) == 0: + convert_dict.update(generate_convert_dict()) + return convert_dict.copy() + + +def generate_convert_dict(): + """Generate convert_dict using lookup.""" + from skbase.utils.dependencies import _check_estimator_deps + + from skpro.utils.retrieval import _all_classes + + classes = _all_classes("skpro.datatypes") + classes = [x[1] for x in classes] + classes = [x for x in classes if issubclass(x, BaseConverter)] + classes = [x for x in classes if not x.__name__.startswith("Base")] + + # subset only to data types with soft dependencies present + result = [x for x in classes if _check_estimator_deps(x, severity="none")] + + check_dict = dict() + for k in result: + from_mtype = k.get_class_tag("from_mtype") + to_mtype = k.get_class_tag("to_mtype") + scitype = k.get_class_tag("scitype") + + convert_dict[(from_mtype, to_mtype, scitype)] = k()._convert + + # temporary while refactoring + check_dict.update(convert_dict_Proba) + check_dict.update(convert_dict_Table) + + return check_dict def convert( From d8e3130b003587a3d08fd06788a4aa301bab956d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Mon, 22 Jul 2024 21:59:47 +0100 Subject: [PATCH 20/40] Update _check.py --- skpro/datatypes/_proba/_check.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skpro/datatypes/_proba/_check.py b/skpro/datatypes/_proba/_check.py index 4355706c9..c6c1c063f 100644 --- a/skpro/datatypes/_proba/_check.py +++ b/skpro/datatypes/_proba/_check.py @@ -1,6 +1,6 @@ -"""Machine type checkers for Series scitype. +"""Machine type checkers for Proba (probabilistic return) scitype. 
-Exports checkers for Series scitype: +Exports checkers for Proba scitype: check_dict: dict indexed by pairs of str 1st element = mtype - str From 0acb1e03672ac0a3454a7eacfd7079f8f35e8a8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 00:05:31 +0100 Subject: [PATCH 21/40] use _get_key in type --- skpro/datatypes/_base/_base.py | 9 +++++++++ skpro/datatypes/_check.py | 6 ++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/skpro/datatypes/_base/_base.py b/skpro/datatypes/_base/_base.py index fe663c370..d2b962741 100644 --- a/skpro/datatypes/_base/_base.py +++ b/skpro/datatypes/_base/_base.py @@ -157,6 +157,15 @@ def get(self, key, default=None): """ return getattr(self, key, default) + @classmethod + def _get_key(cls): + """Get unique dictionary key corresponding to self. + + Private function, used in collecting a dictionary of checks. + """ + mtype = cls.get_class_tag("name") + scitype = cls.get_class_tag("scitype") + return (mtype, scitype) class BaseConverter(BaseObject): """Base class for data type converters. diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index 15ea30295..54b7a570e 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -62,10 +62,8 @@ def generate_check_dict(): check_dict = dict() for k in result: - mtype = k.get_class_tag("name") - scitype = k.get_class_tag("scitype") - - check_dict[(mtype, scitype)] = k()._check + key = k._get_key() + check_dict[key] = k()._check # temporary while refactoring check_dict.update(check_dict_Proba) From 2047e6bb37528dc66f42c839b60b69b483263c83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 00:28:59 +0100 Subject: [PATCH 22/40] converter base --- skpro/datatypes/_base/_base.py | 67 ++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/skpro/datatypes/_base/_base.py b/skpro/datatypes/_base/_base.py index d2b962741..42a2847fd 100644 --- a/skpro/datatypes/_base/_base.py +++ b/skpro/datatypes/_base/_base.py @@ -175,16 +175,45 @@ class BaseConverter(BaseObject): _tags = { "object_type": "converter", - "scitype": None, - "mtype_from": None, # equal to name field - "mtype_to": None, # equal to name field + "mtype_from": None, # type to convert from - BaseDatatype class + "mtype_to": None, # type to convert to - BaseDatatype class + "multiple_conversions": False, # whether converter encodes multiple conversions "python_version": None, "python_dependencies": None, } - def __init__(self): + def __init__(self, mtype_from=None, mtype_to=None): + self.mtype_from = mtype_from + self.mtype_to = mtype_to super().__init__() + if mtype_from is not None: + self.set_tags(**{"mtype_from": mtype_from}) + if mtype_to is not None: + self.set_tags(**{"mtype_to": mtype_to}) + + mtype_from = self.get_class_tag("mtype_from") + mtype_to = self.get_class_tag("mtype_to") + + if mtype_from is None: + raise ValueError( + f"Error in instantiating {self.__class__.__name__}: " + "mtype_from and mtype_to must be set if the class has no defaults. " + "For valid pairs of defaults, use get_conversions." + ) + if mtype_to is None: + raise ValueError( + f"Error in instantiating {self.__class__.__name__}: " + "mtype_to must be set in constructor, as the class has no defaults. " + "For valid pairs of defaults, use get_conversions." 
+ ) + if (mtype_from, mtype_to) not in self.get_conversions(): + raise ValueError( + f"Error in instantiating {self.__class__.__name__}: " + "mtype_from and mtype_to must be a valid pair of defaults. " + "For valid pairs of defaults, use get_conversions." + ) + def convert(self, obj, store=None): """Convert obj to another machine type. @@ -208,3 +237,33 @@ def _convert(self, obj, store=None): Reference of storage for lossy conversions. """ raise NotImplementedError + + @classmethod + def get_conversions(cls): + """Get all conversions. + + Returns + ------- + list of tuples (BaseDatatype subclass, BaseDatatype subclass) + List of all conversions in this class. + """ + cls_from = cls.get_class_tag("mtype_from") + cls_to = cls.get_class_tag("mtype_to") + + if cls_from is not None and cls_to is not None: + return [(cls_from, cls_to)] + # if multiple conversions are encoded, this should be overridden + raise NotImplementedError + + @classmethod + def _get_key(cls): + """Get unique dictionary key corresponding to self. + + Private function, used in collecting a dictionary of checks. + """ + cls_from = cls.get_class_tag("mtype_from") + cls_to = cls.get_class_tag("mtype_to") + mtype_from = cls_from.get_class_tag("name") + mtype_to = cls_to.get_class_tag("name") + scitype = cls_to.get_class_tag("scitype") + return (mtype_from, mtype_to, scitype) From 3df7025628429d934febce99bd5189d80f3fa17f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 01:08:22 +0100 Subject: [PATCH 23/40] base clses --- skpro/datatypes/_base/_base.py | 105 +++++++++++++++++++++++++++++---- 1 file changed, 95 insertions(+), 10 deletions(-) diff --git a/skpro/datatypes/_base/_base.py b/skpro/datatypes/_base/_base.py index 42a2847fd..3e5f92511 100644 --- a/skpro/datatypes/_base/_base.py +++ b/skpro/datatypes/_base/_base.py @@ -27,6 +27,30 @@ class BaseDatatype(BaseObject): def __init__(self): super().__init__() + # call defaults to check + def __call__(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : instance of self only returned if return_metadata is True. + Metadata dictionary. + """ + return self.check(obj=obj, return_metadata=return_metadata, var_name=var_name) + def check(self, obj, return_metadata=False, var_name="obj"): """Check if obj is of this data type. @@ -157,14 +181,13 @@ def get(self, key, default=None): """ return getattr(self, key, default) - @classmethod - def _get_key(cls): + def _get_key(self): """Get unique dictionary key corresponding to self. Private function, used in collecting a dictionary of checks. 
""" - mtype = cls.get_class_tag("name") - scitype = cls.get_class_tag("scitype") + mtype = self.get_class_tag("name") + scitype = self.get_class_tag("scitype") return (mtype, scitype) class BaseConverter(BaseObject): @@ -175,8 +198,8 @@ class BaseConverter(BaseObject): _tags = { "object_type": "converter", - "mtype_from": None, # type to convert from - BaseDatatype class - "mtype_to": None, # type to convert to - BaseDatatype class + "mtype_from": None, # type to convert from - BaseDatatype class or str + "mtype_to": None, # type to convert to - BaseDatatype class or str "multiple_conversions": False, # whether converter encodes multiple conversions "python_version": None, "python_dependencies": None, @@ -214,6 +237,24 @@ def __init__(self, mtype_from=None, mtype_to=None): "For valid pairs of defaults, use get_conversions." ) + # call defaults to convert + def __call__(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. + + Returns + ------- + converted_obj : any + Object obj converted to another machine type. + """ + return self.convert(obj=obj, store=store) + def convert(self, obj, store=None): """Convert obj to another machine type. @@ -255,15 +296,59 @@ def get_conversions(cls): # if multiple conversions are encoded, this should be overridden raise NotImplementedError - @classmethod - def _get_key(cls): + def _get_cls_from_to(self): + """Get classes from and to. + + Returns + ------- + cls_from : BaseDatatype subclass + Class to convert from. + cls_to : BaseDatatype subclass + Class to convert to. + """ + cls_from = self.get_class_tag("mtype_from") + cls_to = self.get_class_tag("mtype_to") + + cls_from = _coerce_str_to_cls(cls_from) + cls_to = _coerce_str_to_cls(cls_to) + + return cls_from, cls_to + + def _get_key(self): """Get unique dictionary key corresponding to self. Private function, used in collecting a dictionary of checks. """ - cls_from = cls.get_class_tag("mtype_from") - cls_to = cls.get_class_tag("mtype_to") + cls_from, cls_to = self._get_cls_from_to() + mtype_from = cls_from.get_class_tag("name") mtype_to = cls_to.get_class_tag("name") scitype = cls_to.get_class_tag("scitype") return (mtype_from, mtype_to, scitype) + + +def _coerce_str_to_cls(cls_or_str): + """Get class from string. + + Parameters + ---------- + cls_or_str : str or class + Class or string. If string, assumed to be a unique mtype string from + one of the BaseDatatype subclasses. + + Returns + ------- + cls : cls_or_str, if was class; otherwise, class corresponding to string. + """ + if not isinstance(cls_or_str, str): + return cls_or_str + + # otherwise, we use the string to get the class from the check dict + # perhaps it is nicer to transfer this to a registry later. 
+ from skpro.datatypes._check import get_check_dict + + cd = get_check_dict() + cls = [cd[k].__class__ for k in cd if k[0] == cls_or_str] + if len(cls) != 1: + raise ValueError(f"Error in converting string to class: {cls_or_str}") + return cls[0] From 7f5bf901580c9fad289aa85e7117d533bc9a3486 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 01:08:32 +0100 Subject: [PATCH 24/40] cache checkers --- skpro/datatypes/_check.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index 54b7a570e..71c56343e 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -23,6 +23,8 @@ "mtype", ] +from functools import lru_cache + import numpy as np from skpro.datatypes._base import BaseDatatype @@ -30,22 +32,18 @@ from skpro.datatypes._proba import check_dict_Proba from skpro.datatypes._registry import AMBIGUOUS_MTYPES, SCITYPE_LIST, mtype_to_scitype -check_dict = {} - def get_check_dict(): """Retrieve check_dict, caches the first time it is requested. This is to avoid repeated, time consuming crawling in generate_check_dict, which would otherwise be called every time check_dict is requested. - - Leaving the code on root level will also fail, due to circular imports. """ - if len(check_dict) == 0: - check_dict.update(generate_check_dict()) + check_dict = generate_check_dict() return check_dict.copy() +@lru_cache(maxsize=1) def generate_check_dict(): """Generate check_dict using lookup.""" from skbase.utils.dependencies import _check_estimator_deps @@ -61,9 +59,10 @@ def generate_check_dict(): result = [x for x in classes if _check_estimator_deps(x, severity="none")] check_dict = dict() - for k in result: + for cls in result: + k = cls() key = k._get_key() - check_dict[key] = k()._check + check_dict[key] = k # temporary while refactoring check_dict.update(check_dict_Proba) @@ -190,6 +189,7 @@ def check_is_mtype( """ mtype = _coerce_list_of_str(mtype, var_name="mtype") + check_dict = get_check_dict() valid_keys = check_dict.keys() # we loop through individual mtypes in mtype and see whether they pass the check @@ -332,6 +332,7 @@ def mtype(obj, as_scitype=None, exclude_mtypes=AMBIGUOUS_MTYPES): for scitype in as_scitype: _check_scitype_valid(scitype) + check_dict = get_check_dict() m_plus_scitypes = [ (x[0], x[1]) for x in check_dict.keys() if x[0] not in exclude_mtypes ] @@ -441,6 +442,7 @@ def check_is_scitype( for x in scitype: _check_scitype_valid(x) + check_dict = get_check_dict() valid_keys = check_dict.keys() # find all the mtype keys corresponding to the scitypes From 6c24d0847a8b5750f835ca07fad702ce33d6c4d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 01:08:45 +0100 Subject: [PATCH 25/40] converters - experimental --- skpro/datatypes/_convert.py | 31 ++++++++---- skpro/datatypes/_table/_convert.py | 77 +++++++++++++++++++++++++----- 2 files changed, 87 insertions(+), 21 deletions(-) diff --git a/skpro/datatypes/_convert.py b/skpro/datatypes/_convert.py index 50721dd14..4079eb7e1 100644 --- a/skpro/datatypes/_convert.py +++ b/skpro/datatypes/_convert.py @@ -65,6 +65,7 @@ ] from copy import deepcopy +from functools import lru_cache import numpy as np import pandas as pd @@ -84,14 +85,12 @@ def get_convert_dict(): This is to avoid repeated, time consuming crawling in generate_check_dict, which would otherwise be called every time check_dict is requested. 
- - Leaving the code on root level will also fail, due to circular imports. """ - if len(convert_dict) == 0: - convert_dict.update(generate_convert_dict()) + convert_dict = generate_convert_dict() return convert_dict.copy() +@lru_cache(maxsize=1) def generate_convert_dict(): """Generate convert_dict using lookup.""" from skbase.utils.dependencies import _check_estimator_deps @@ -107,12 +106,24 @@ def generate_convert_dict(): result = [x for x in classes if _check_estimator_deps(x, severity="none")] check_dict = dict() - for k in result: - from_mtype = k.get_class_tag("from_mtype") - to_mtype = k.get_class_tag("to_mtype") - scitype = k.get_class_tag("scitype") - - convert_dict[(from_mtype, to_mtype, scitype)] = k()._convert + for cls in result: + if not cls.get_class_tag("multiple_conversions", False): + k = cls() + key = k._get_key() + convert_dict[key] = k + else: + for cls_to_cls in k.get_conversions(): + k = k(*cls_to_cls) + + # check dependencies for both classes + # only add conversions if dependencies are satisfied for to and from + cls_from, cls_to = k._get_cls_from_to() + from_dep_chk = _check_estimator_deps(cls_from, severity="none") + to_dep_chk = _check_estimator_deps(cls_to, severity="none") + + if from_dep_chk and to_dep_chk: + key = k._get_key() + convert_dict[key] = k # temporary while refactoring check_dict.update(convert_dict_Proba) diff --git a/skpro/datatypes/_table/_convert.py b/skpro/datatypes/_table/_convert.py index 6fed2231b..d88db021e 100644 --- a/skpro/datatypes/_table/_convert.py +++ b/skpro/datatypes/_table/_convert.py @@ -33,8 +33,8 @@ import numpy as np import pandas as pd +from skpro.datatypes._table._base import BaseConverter from skpro.datatypes._convert_utils._convert import _extend_conversions -from skpro.datatypes._table._registry import MTYPE_LIST_TABLE from skpro.utils.validation._dependencies import _check_soft_dependencies ############################################################## @@ -44,13 +44,71 @@ convert_dict = dict() -def convert_identity(obj, store=None): - return obj - - -# assign identity function to type conversion to self -for tp in MTYPE_LIST_TABLE: - convert_dict[(tp, tp, "Table")] = convert_identity +class TableIdentity(BaseConverter): + """All Table scitype conversions of any mtype to itself. + + This is the identity conversion for Table scitype, + no coercion is done, the object is returned as is. + """ + + _tags = { + "object_type": "converter", + "mtype_from": None, + "mtype_to": None, + "multiple_conversions": True, + "python_version": None, + "python_dependencies": None, + } + + @classmethod + def get_conversions(cls): + """Get all conversions. + + Returns + ------- + list of tuples (BaseDatatype subclass, BaseDatatype subclass) + List of all conversions in this class. + """ + from skpro.datatypes._table._registry import MTYPE_LIST_TABLE + + return [(tp, tp) for tp in MTYPE_LIST_TABLE] + + # identity conversion + def _convert(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. 
+ """ + return obj + +class Numpy1dToNumpy2D(BaseConverter): + """Conversion: numpy1D -> numpy2D, of Table scitype.""" + + _tags = { + "object_type": "converter", + "mtype_from": "numpy1D", # type to convert from - BaseDatatype class or str + "mtype_to": "numpy2D", # type to convert to - BaseDatatype class or str + "multiple_conversions": False, # whether converter encodes multiple conversions + "python_version": None, + "python_dependencies": None, + } + + def _convert(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. + """ + return convert_1D_to_2D_numpy_as_Table(obj=obj, store=store) def convert_1D_to_2D_numpy_as_Table(obj: np.ndarray, store=None) -> np.ndarray: @@ -65,9 +123,6 @@ def convert_1D_to_2D_numpy_as_Table(obj: np.ndarray, store=None) -> np.ndarray: return res -convert_dict[("numpy1D", "numpy2D", "Table")] = convert_1D_to_2D_numpy_as_Table - - def convert_2D_to_1D_numpy_as_Table(obj: np.ndarray, store=None) -> np.ndarray: if not isinstance(obj, np.ndarray): raise TypeError("input must be a np.ndarray") From 47fbb6cee7af7e7659fdce113eeda83044537a71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 01:09:13 +0100 Subject: [PATCH 26/40] Update _convert.py --- skpro/datatypes/_table/_convert.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/skpro/datatypes/_table/_convert.py b/skpro/datatypes/_table/_convert.py index d88db021e..c2b3ae1ab 100644 --- a/skpro/datatypes/_table/_convert.py +++ b/skpro/datatypes/_table/_convert.py @@ -33,8 +33,9 @@ import numpy as np import pandas as pd -from skpro.datatypes._table._base import BaseConverter from skpro.datatypes._convert_utils._convert import _extend_conversions +from skpro.datatypes._table._base import BaseConverter +from skpro.datatypes._table._registry import MTYPE_LIST_TABLE from skpro.utils.validation._dependencies import _check_soft_dependencies ############################################################## @@ -69,8 +70,6 @@ def get_conversions(cls): list of tuples (BaseDatatype subclass, BaseDatatype subclass) List of all conversions in this class. """ - from skpro.datatypes._table._registry import MTYPE_LIST_TABLE - return [(tp, tp) for tp in MTYPE_LIST_TABLE] # identity conversion From a2b91cd78b01cfc8e5e65aa769c31a4638d1811d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 01:13:33 +0100 Subject: [PATCH 27/40] linting --- skpro/datatypes/_base/_base.py | 1 + skpro/datatypes/_table/_convert.py | 1 + 2 files changed, 2 insertions(+) diff --git a/skpro/datatypes/_base/_base.py b/skpro/datatypes/_base/_base.py index 3e5f92511..09870e906 100644 --- a/skpro/datatypes/_base/_base.py +++ b/skpro/datatypes/_base/_base.py @@ -190,6 +190,7 @@ def _get_key(self): scitype = self.get_class_tag("scitype") return (mtype, scitype) + class BaseConverter(BaseObject): """Base class for data type converters. 
diff --git a/skpro/datatypes/_table/_convert.py b/skpro/datatypes/_table/_convert.py index c2b3ae1ab..e1fdc3ee0 100644 --- a/skpro/datatypes/_table/_convert.py +++ b/skpro/datatypes/_table/_convert.py @@ -85,6 +85,7 @@ def _convert(self, obj, store=None): """ return obj + class Numpy1dToNumpy2D(BaseConverter): """Conversion: numpy1D -> numpy2D, of Table scitype.""" From 97234e905279b8c065e614b6f73e777d6a4845d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 01:15:53 +0100 Subject: [PATCH 28/40] Update _convert.py --- skpro/datatypes/_table/_convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/datatypes/_table/_convert.py b/skpro/datatypes/_table/_convert.py index e1fdc3ee0..3eac8a047 100644 --- a/skpro/datatypes/_table/_convert.py +++ b/skpro/datatypes/_table/_convert.py @@ -33,8 +33,8 @@ import numpy as np import pandas as pd +from skpro.datatypes._base import BaseConverter from skpro.datatypes._convert_utils._convert import _extend_conversions -from skpro.datatypes._table._base import BaseConverter from skpro.datatypes._table._registry import MTYPE_LIST_TABLE from skpro.utils.validation._dependencies import _check_soft_dependencies From 304819eae0ff209af65267b500fb41d23766dac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 01:23:31 +0100 Subject: [PATCH 29/40] Update _base.py --- skpro/datatypes/_base/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/datatypes/_base/_base.py b/skpro/datatypes/_base/_base.py index 09870e906..d72325998 100644 --- a/skpro/datatypes/_base/_base.py +++ b/skpro/datatypes/_base/_base.py @@ -49,7 +49,7 @@ def __call__(self, obj, return_metadata=False, var_name="obj"): metadata : instance of self only returned if return_metadata is True. Metadata dictionary. """ - return self.check(obj=obj, return_metadata=return_metadata, var_name=var_name) + return self._check(obj=obj, return_metadata=return_metadata, var_name=var_name) def check(self, obj, return_metadata=False, var_name="obj"): """Check if obj is of this data type. From 0faf0b8b498fe74bc3563d6068c5614f6e4dff0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 01:39:21 +0100 Subject: [PATCH 30/40] bugfixes --- skpro/datatypes/_base/_base.py | 14 ++++++++------ skpro/datatypes/_convert.py | 32 ++++++++++++++++++-------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/skpro/datatypes/_base/_base.py b/skpro/datatypes/_base/_base.py index d72325998..40ca2e8f3 100644 --- a/skpro/datatypes/_base/_base.py +++ b/skpro/datatypes/_base/_base.py @@ -216,8 +216,8 @@ def __init__(self, mtype_from=None, mtype_to=None): if mtype_to is not None: self.set_tags(**{"mtype_to": mtype_to}) - mtype_from = self.get_class_tag("mtype_from") - mtype_to = self.get_class_tag("mtype_to") + mtype_from = self.get_tag("mtype_from") + mtype_to = self.get_tag("mtype_to") if mtype_from is None: raise ValueError( @@ -231,7 +231,7 @@ def __init__(self, mtype_from=None, mtype_to=None): "mtype_to must be set in constructor, as the class has no defaults. " "For valid pairs of defaults, use get_conversions." ) - if (mtype_from, mtype_to) not in self.get_conversions(): + if (mtype_from, mtype_to) not in self.__class__.get_conversions(): raise ValueError( f"Error in instantiating {self.__class__.__name__}: " "mtype_from and mtype_to must be a valid pair of defaults. 
" @@ -307,8 +307,8 @@ def _get_cls_from_to(self): cls_to : BaseDatatype subclass Class to convert to. """ - cls_from = self.get_class_tag("mtype_from") - cls_to = self.get_class_tag("mtype_to") + cls_from = self.get_tag("mtype_from") + cls_to = self.get_tag("mtype_to") cls_from = _coerce_str_to_cls(cls_from) cls_to = _coerce_str_to_cls(cls_to) @@ -350,6 +350,8 @@ def _coerce_str_to_cls(cls_or_str): cd = get_check_dict() cls = [cd[k].__class__ for k in cd if k[0] == cls_or_str] - if len(cls) != 1: + if len(cls) > 1: raise ValueError(f"Error in converting string to class: {cls_or_str}") + elif len(cls) < 1: + return None return cls[0] diff --git a/skpro/datatypes/_convert.py b/skpro/datatypes/_convert.py index 4079eb7e1..b5c1dad7a 100644 --- a/skpro/datatypes/_convert.py +++ b/skpro/datatypes/_convert.py @@ -76,9 +76,6 @@ from skpro.datatypes._registry import mtype_to_scitype from skpro.datatypes._table import convert_dict_Table -# pool convert_dict-s and infer_mtype_dict-s -convert_dict = {} - def get_convert_dict(): """Retrieve convert_dict, caches the first time it is requested. @@ -105,31 +102,36 @@ def generate_convert_dict(): # subset only to data types with soft dependencies present result = [x for x in classes if _check_estimator_deps(x, severity="none")] - check_dict = dict() + convert_dict = dict() for cls in result: if not cls.get_class_tag("multiple_conversions", False): k = cls() key = k._get_key() convert_dict[key] = k else: - for cls_to_cls in k.get_conversions(): - k = k(*cls_to_cls) + for cls_to_cls in cls.get_conversions(): + k = cls(*cls_to_cls) # check dependencies for both classes # only add conversions if dependencies are satisfied for to and from cls_from, cls_to = k._get_cls_from_to() - from_dep_chk = _check_estimator_deps(cls_from, severity="none") - to_dep_chk = _check_estimator_deps(cls_to, severity="none") - if from_dep_chk and to_dep_chk: - key = k._get_key() - convert_dict[key] = k + # do not add conversion if dependencies are not satisfied + if cls_from is None or cls_to is None: + continue + if not _check_estimator_deps(cls_from, severity="none"): + continue + if not _check_estimator_deps(cls_to, severity="none"): + continue + + key = k._get_key() + convert_dict[key] = k # temporary while refactoring - check_dict.update(convert_dict_Proba) - check_dict.update(convert_dict_Table) + convert_dict.update(convert_dict_Proba) + convert_dict.update(convert_dict_Table) - return check_dict + return convert_dict def convert( @@ -195,6 +197,7 @@ def convert( key = (from_type, to_type, as_scitype) + convert_dict = get_convert_dict() if key not in convert_dict.keys(): raise NotImplementedError( "no conversion defined from type " + str(from_type) + " to " + str(to_type) @@ -331,6 +334,7 @@ def _conversions_defined(scitype: str): entry of row i, col j is 1 if conversion from i to j is defined, 0 if conversion from i to j is not defined """ + convert_dict = get_convert_dict() pairs = [(x[0], x[1]) for x in list(convert_dict.keys()) if x[2] == scitype] cols0 = {x[0] for x in list(convert_dict.keys()) if x[2] == scitype} cols1 = {x[1] for x in list(convert_dict.keys()) if x[2] == scitype} From 710fb79e435807621c87cf777e3ec83b27dfb40a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 01:43:30 +0100 Subject: [PATCH 31/40] tests --- skpro/datatypes/tests/test_check.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/skpro/datatypes/tests/test_check.py b/skpro/datatypes/tests/test_check.py index 5fdc0d1cb..ddb916e97 100644 
--- a/skpro/datatypes/tests/test_check.py +++ b/skpro/datatypes/tests/test_check.py @@ -6,9 +6,9 @@ from skpro.datatypes._check import ( AMBIGUOUS_MTYPES, - check_dict, check_is_mtype, check_is_scitype, + get_check_dict, ) from skpro.datatypes._check import mtype as infer_mtype from skpro.datatypes._check import scitype as infer_scitype @@ -123,6 +123,7 @@ def test_check_positive(scitype, mtype, fixture_index): fixture = get_examples(mtype=mtype, as_scitype=scitype).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that exist against checks that exist, when full metadata is queried @@ -174,6 +175,7 @@ def test_check_positive_check_scitype(scitype, mtype, fixture_index): fixture = get_examples(mtype=mtype, as_scitype=scitype).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that exist against checks that exist, when full metadata is queried @@ -222,6 +224,7 @@ def test_check_metadata_inference(scitype, mtype, fixture_index): ).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # if the examples have no metadata to them, don't test metadata_provided = expected_metadata is not None @@ -340,6 +343,7 @@ def test_check_negative(scitype, mtype): fixture_wrong_type = fixtures[wrong_mtype].get(i) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that exist against checks that exist @@ -392,6 +396,7 @@ def test_mtype_infer(scitype, mtype, fixture_index): fixture = get_examples(mtype=mtype, as_scitype=scitype).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that exist against checks that exist @@ -440,6 +445,7 @@ def test_scitype_infer(scitype, mtype, fixture_index): fixture = get_examples(mtype=mtype, as_scitype=scitype).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that exist against checks that exist From 05fc2100513d953af5516c769a8fefc53b13caf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 12:25:44 +0100 Subject: [PATCH 32/40] handle soft deps --- skpro/datatypes/_registry.py | 5 ++++- skpro/datatypes/tests/test_convert.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/skpro/datatypes/_registry.py b/skpro/datatypes/_registry.py index bc0c36c36..4da55b86a 100644 --- a/skpro/datatypes/_registry.py +++ b/skpro/datatypes/_registry.py @@ -46,7 +46,10 @@ MTYPE_REGISTER += MTYPE_REGISTER_TABLE MTYPE_REGISTER += MTYPE_REGISTER_PROBA -MTYPE_SOFT_DEPS = {} +MTYPE_SOFT_DEPS = { + "polars_eager_table": "polars", + "polars_lazy_table": "polars", +} # mtypes to exclude in checking since they are ambiguous and rare diff --git a/skpro/datatypes/tests/test_convert.py b/skpro/datatypes/tests/test_convert.py index 6a0f82af6..91632ecf3 100644 --- a/skpro/datatypes/tests/test_convert.py +++ b/skpro/datatypes/tests/test_convert.py @@ -26,7 +26,7 @@ def _generate_fixture_tuples(): conv_mat = 
_conversions_defined(scitype) - mtypes = scitype_to_mtype(scitype, softdeps="exclude") + mtypes = scitype_to_mtype(scitype, softdeps="present") if len(mtypes) == 0: # if there are no mtypes, this must have been reached by mistake/bug From 84410d9160b78ba679c5aeceb53b2b957335f66c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 18:24:11 +0100 Subject: [PATCH 33/40] handle softdeps --- skpro/datatypes/_convert.py | 24 +++++++++++++++++------- skpro/datatypes/tests/test_convert.py | 2 +- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/skpro/datatypes/_convert.py b/skpro/datatypes/_convert.py index b5c1dad7a..9b5c69f6a 100644 --- a/skpro/datatypes/_convert.py +++ b/skpro/datatypes/_convert.py @@ -77,18 +77,24 @@ from skpro.datatypes._table import convert_dict_Table -def get_convert_dict(): +def get_convert_dict(soft_deps="present"): """Retrieve convert_dict, caches the first time it is requested. This is to avoid repeated, time consuming crawling in generate_check_dict, which would otherwise be called every time check_dict is requested. + + Parameters + ---------- + soft_deps : str, optional - one of "present", "all" + "present" - only conversions with soft dependencies present are included + "all" - all conversions are included """ - convert_dict = generate_convert_dict() + convert_dict = generate_convert_dict(soft_deps=soft_deps) return convert_dict.copy() @lru_cache(maxsize=1) -def generate_convert_dict(): +def generate_convert_dict(soft_deps="present"): """Generate convert_dict using lookup.""" from skbase.utils.dependencies import _check_estimator_deps @@ -119,9 +125,10 @@ def generate_convert_dict(): # do not add conversion if dependencies are not satisfied if cls_from is None or cls_to is None: continue - if not _check_estimator_deps(cls_from, severity="none"): + filter_sd = soft_deps in ["present"] + if filter_sd and not _check_estimator_deps(cls_from, severity="none"): continue - if not _check_estimator_deps(cls_to, severity="none"): + if filter_sd and not _check_estimator_deps(cls_to, severity="none"): continue key = k._get_key() @@ -320,13 +327,16 @@ def convert_to( return converted_obj -def _conversions_defined(scitype: str): +def _conversions_defined(scitype: str, soft_deps: str = "present"): """Return an indicator matrix which conversions are defined for scitype. 
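The soft_deps handling above combines functools.lru_cache with a defensive copy, so repeated registry lookups are cheap while callers cannot mutate the cached dict. A standalone sketch of that pattern (illustrative names, not the skpro code itself):

from functools import lru_cache

@lru_cache(maxsize=1)
def _generate_registry(soft_deps="present"):
    # stand-in for the expensive module crawl done by generate_convert_dict;
    # note: with maxsize=1, alternating "present"/"all" recomputes the crawl
    return {("numpy1D", "numpy2D", "Table"): "converter instance"}

def get_registry(soft_deps="present"):
    if soft_deps not in ("present", "all"):
        raise ValueError(f"soft_deps must be 'present' or 'all', found {soft_deps}")
    # return a copy so that mutating the result does not corrupt the cache
    return _generate_registry(soft_deps=soft_deps).copy()

reg = get_registry()
reg.clear()                      # mutates only the copy
assert len(get_registry()) == 1  # cached registry is unchanged
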
Parameters ---------- scitype: str - name of scitype for which conversions are queried valid scitype strings, with explanation, are in datatypes.SCITYPE_REGISTER + soft_deps : str, optional - one of "present", "all" + "present" - only conversions with soft dependencies present are included + "all" - all conversions are included Returns ------- @@ -334,7 +344,7 @@ def _conversions_defined(scitype: str): entry of row i, col j is 1 if conversion from i to j is defined, 0 if conversion from i to j is not defined """ - convert_dict = get_convert_dict() + convert_dict = get_convert_dict(soft_deps=soft_deps) pairs = [(x[0], x[1]) for x in list(convert_dict.keys()) if x[2] == scitype] cols0 = {x[0] for x in list(convert_dict.keys()) if x[2] == scitype} cols1 = {x[1] for x in list(convert_dict.keys()) if x[2] == scitype} diff --git a/skpro/datatypes/tests/test_convert.py b/skpro/datatypes/tests/test_convert.py index 91632ecf3..a783eb390 100644 --- a/skpro/datatypes/tests/test_convert.py +++ b/skpro/datatypes/tests/test_convert.py @@ -24,7 +24,7 @@ def _generate_fixture_tuples(): if scitype in SCITYPES_NO_CONVERSIONS: continue - conv_mat = _conversions_defined(scitype) + conv_mat = _conversions_defined(scitype, soft_deps="present") mtypes = scitype_to_mtype(scitype, softdeps="present") From 7e6025155e06245eba4bb3e50d18a7fdf56d7373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 18:53:02 +0100 Subject: [PATCH 34/40] manage soft deps --- skpro/datatypes/_base/_base.py | 2 +- skpro/datatypes/_check.py | 20 ++++++++++++++++---- skpro/datatypes/_convert.py | 5 +++++ skpro/datatypes/tests/test_convert.py | 2 +- 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/skpro/datatypes/_base/_base.py b/skpro/datatypes/_base/_base.py index 40ca2e8f3..65d36e3e8 100644 --- a/skpro/datatypes/_base/_base.py +++ b/skpro/datatypes/_base/_base.py @@ -348,7 +348,7 @@ def _coerce_str_to_cls(cls_or_str): # perhaps it is nicer to transfer this to a registry later. from skpro.datatypes._check import get_check_dict - cd = get_check_dict() + cd = get_check_dict(soft_deps="all") cls = [cd[k].__class__ for k in cd if k[0] == cls_or_str] if len(cls) > 1: raise ValueError(f"Error in converting string to class: {cls_or_str}") diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index 71c56343e..bfe05c8cd 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -33,18 +33,29 @@ from skpro.datatypes._registry import AMBIGUOUS_MTYPES, SCITYPE_LIST, mtype_to_scitype -def get_check_dict(): +def get_check_dict(soft_deps="present"): """Retrieve check_dict, caches the first time it is requested. This is to avoid repeated, time consuming crawling in generate_check_dict, which would otherwise be called every time check_dict is requested. 
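_conversions_defined, extended above, summarises the registry as an indicator matrix over the mtypes of one scitype. A standalone sketch of how such a matrix is assembled from (from_mtype, to_mtype, scitype) keys, with illustrative keys only:

import pandas as pd

registry_keys = [
    ("numpy1D", "numpy2D", "Table"),
    ("numpy2D", "numpy1D", "Table"),
    ("pd_DataFrame_Table", "pd_DataFrame_Table", "Table"),
]

pairs = [(k[0], k[1]) for k in registry_keys if k[2] == "Table"]
mtypes = sorted({m for pair in pairs for m in pair})
mat = pd.DataFrame(0, index=mtypes, columns=mtypes)
for frm, to in pairs:
    mat.loc[frm, to] = 1   # 1 if a converter from row mtype to column mtype exists
print(mat)
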
+ + Parameters + ---------- + soft_deps : str, optional - one of "present", "all" + "present" - only checks with soft dependencies present are included + "all" - all checks are included """ - check_dict = generate_check_dict() + if soft_deps not in ["present", "all"]: + raise ValueError( + "Error in get_check_dict, soft_deps argument must be 'present' or 'all', " + f"found {soft_deps}" + ) + check_dict = generate_check_dict(soft_deps=soft_deps) return check_dict.copy() @lru_cache(maxsize=1) -def generate_check_dict(): +def generate_check_dict(soft_deps="present"): """Generate check_dict using lookup.""" from skbase.utils.dependencies import _check_estimator_deps @@ -56,7 +67,8 @@ def generate_check_dict(): classes = [x for x in classes if not x.__name__.startswith("Base")] # subset only to data types with soft dependencies present - result = [x for x in classes if _check_estimator_deps(x, severity="none")] + if soft_deps == "present": + result = [x for x in classes if _check_estimator_deps(x, severity="none")] check_dict = dict() for cls in result: diff --git a/skpro/datatypes/_convert.py b/skpro/datatypes/_convert.py index 9b5c69f6a..d9a92ead9 100644 --- a/skpro/datatypes/_convert.py +++ b/skpro/datatypes/_convert.py @@ -89,6 +89,11 @@ def get_convert_dict(soft_deps="present"): "present" - only conversions with soft dependencies present are included "all" - all conversions are included """ + if soft_deps not in ["present", "all"]: + raise ValueError( + "Error in get_check_dict, soft_deps argument must be 'present' or 'all', " + f"found {soft_deps}" + ) convert_dict = generate_convert_dict(soft_deps=soft_deps) return convert_dict.copy() diff --git a/skpro/datatypes/tests/test_convert.py b/skpro/datatypes/tests/test_convert.py index a783eb390..5acb3fdc4 100644 --- a/skpro/datatypes/tests/test_convert.py +++ b/skpro/datatypes/tests/test_convert.py @@ -24,7 +24,7 @@ def _generate_fixture_tuples(): if scitype in SCITYPES_NO_CONVERSIONS: continue - conv_mat = _conversions_defined(scitype, soft_deps="present") + conv_mat = _conversions_defined(scitype, soft_deps="all") mtypes = scitype_to_mtype(scitype, softdeps="present") From 33df86450dbff4d6942ea678f5e34a01d82e27cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 25 Jul 2024 22:58:13 +0100 Subject: [PATCH 35/40] bugfix --- skpro/datatypes/_check.py | 4 ++-- skpro/datatypes/_convert.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index bfe05c8cd..f36723dbd 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -68,10 +68,10 @@ def generate_check_dict(soft_deps="present"): # subset only to data types with soft dependencies present if soft_deps == "present": - result = [x for x in classes if _check_estimator_deps(x, severity="none")] + classes = [x for x in classes if _check_estimator_deps(x, severity="none")] check_dict = dict() - for cls in result: + for cls in classes: k = cls() key = k._get_key() check_dict[key] = k diff --git a/skpro/datatypes/_convert.py b/skpro/datatypes/_convert.py index d9a92ead9..6883b99e0 100644 --- a/skpro/datatypes/_convert.py +++ b/skpro/datatypes/_convert.py @@ -111,10 +111,11 @@ def generate_convert_dict(soft_deps="present"): classes = [x for x in classes if not x.__name__.startswith("Base")] # subset only to data types with soft dependencies present - result = [x for x in classes if _check_estimator_deps(x, severity="none")] + if soft_deps == "present": + classes = [x for x in classes if 
_check_estimator_deps(x, severity="none")] convert_dict = dict() - for cls in result: + for cls in classes: if not cls.get_class_tag("multiple_conversions", False): k = cls() key = k._get_key() From 39e6223ec79e40cd797a0fad8550fd79d6b221ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Mon, 12 Aug 2024 11:06:44 +0100 Subject: [PATCH 36/40] Update _check.py --- skpro/datatypes/_table/_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/datatypes/_table/_check.py b/skpro/datatypes/_table/_check.py index f02cac429..290c1f94d 100644 --- a/skpro/datatypes/_table/_check.py +++ b/skpro/datatypes/_table/_check.py @@ -1,4 +1,4 @@ -"""Machine type cclasses for Table scitype. +"""Machine type classes for Table scitype. Checks for each class are defined in the "check" method, of signature: From 698774b6c53e611e1f338d360de1a13fceb09153 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 7 Sep 2024 14:54:12 +0100 Subject: [PATCH 37/40] revert test_polars changes --- skpro/datatypes/tests/test_polars.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/skpro/datatypes/tests/test_polars.py b/skpro/datatypes/tests/test_polars.py index 26bfe146f..27b425aa8 100644 --- a/skpro/datatypes/tests/test_polars.py +++ b/skpro/datatypes/tests/test_polars.py @@ -11,7 +11,7 @@ if _check_soft_dependencies(["polars", "pyarrow"], severity="none"): import polars as pl - from skpro.datatypes._adapter.polars import check_polars_frame + from skpro.datatypes._table._check import check_polars_table from skpro.datatypes._table._convert import convert_pandas_to_polars_eager TEST_ALPHAS = [0.05, 0.1, 0.25] @@ -83,9 +83,10 @@ def test_polars_eager_conversion_methods( X_train, X_test, y_train = polars_load_diabetes_pandas X_train_pl, X_test_pl, y_train_pl = polars_load_diabetes_polars - assert check_polars_frame(X_train_pl) - assert check_polars_frame(X_test_pl) - assert check_polars_frame(y_train_pl) + assert check_polars_table(X_train_pl) + assert check_polars_table(X_test_pl) + assert check_polars_table(y_train_pl) + assert (X_train.values == X_train_pl.to_numpy()).all() assert (X_test.values == X_test_pl.to_numpy()).all() assert (y_train.values == y_train_pl.to_numpy()).all() From 2a9b2b5e9d7ad76b922b6564043edc2bf6fe0968 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 7 Sep 2024 16:59:47 +0100 Subject: [PATCH 38/40] Update test_polars.py --- skpro/datatypes/tests/test_polars.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/skpro/datatypes/tests/test_polars.py b/skpro/datatypes/tests/test_polars.py index 27b425aa8..643ac3f07 100644 --- a/skpro/datatypes/tests/test_polars.py +++ b/skpro/datatypes/tests/test_polars.py @@ -11,8 +11,7 @@ if _check_soft_dependencies(["polars", "pyarrow"], severity="none"): import polars as pl - from skpro.datatypes._table._check import check_polars_table - from skpro.datatypes._table._convert import convert_pandas_to_polars_eager + from skpro.datatypes import check_is_mtype, convert TEST_ALPHAS = [0.05, 0.1, 0.25] @@ -43,12 +42,15 @@ def estimator(): return _estimator +def _pd_to_pl(df): + return convert(df, from_type="pd_Series_Table", to_type="polars_eager_table") + @pytest.fixture def polars_load_diabetes_polars(polars_load_diabetes_pandas): X_train, X_test, y_train = polars_load_diabetes_pandas - X_train_pl = convert_pandas_to_polars_eager(X_train) - X_test_pl = convert_pandas_to_polars_eager(X_test) - y_train_pl = 
convert_pandas_to_polars_eager(y_train) + X_train_pl = _pd_to_pl(X_train) + X_test_pl = _pd_to_pl(X_test) + y_train_pl = _pd_to_pl(y_train) # drop the index in the polars frame X_train_pl = X_train_pl.drop(["__index__"]) @@ -60,9 +62,9 @@ def polars_load_diabetes_polars(polars_load_diabetes_pandas): def polars_load_diabetes_polars_with_index(polars_load_diabetes_pandas): X_train, X_test, y_train = polars_load_diabetes_pandas - X_train_pl = convert_pandas_to_polars_eager(X_train) - X_test_pl = convert_pandas_to_polars_eager(X_test) - y_train_pl = convert_pandas_to_polars_eager(y_train) + X_train_pl = _pd_to_pl(X_train) + X_test_pl = _pd_to_pl(X_test) + y_train_pl = _pd_to_pl(y_train) return [X_train_pl, X_test_pl, y_train_pl] @@ -83,9 +85,9 @@ def test_polars_eager_conversion_methods( X_train, X_test, y_train = polars_load_diabetes_pandas X_train_pl, X_test_pl, y_train_pl = polars_load_diabetes_polars - assert check_polars_table(X_train_pl) - assert check_polars_table(X_test_pl) - assert check_polars_table(y_train_pl) + assert check_is_mtype(X_train_pl, "polars_eager_table") + assert check_is_mtype(X_test_pl, "polars_eager_table") + assert check_is_mtype(y_train_pl, "polars_eager_table") assert (X_train.values == X_train_pl.to_numpy()).all() assert (X_test.values == X_test_pl.to_numpy()).all() From f317b8234db7e86a6ded632de45f1cdf4566c78c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 7 Sep 2024 17:36:56 +0100 Subject: [PATCH 39/40] Update test_polars.py --- skpro/datatypes/tests/test_polars.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skpro/datatypes/tests/test_polars.py b/skpro/datatypes/tests/test_polars.py index 643ac3f07..fcec3c92d 100644 --- a/skpro/datatypes/tests/test_polars.py +++ b/skpro/datatypes/tests/test_polars.py @@ -45,6 +45,7 @@ def estimator(): def _pd_to_pl(df): return convert(df, from_type="pd_Series_Table", to_type="polars_eager_table") + @pytest.fixture def polars_load_diabetes_polars(polars_load_diabetes_pandas): X_train, X_test, y_train = polars_load_diabetes_pandas From f60aee2463ad7feb497eb18d65e4c3eec3f17c92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 7 Sep 2024 17:58:08 +0100 Subject: [PATCH 40/40] Update test_polars.py --- skpro/datatypes/tests/test_polars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skpro/datatypes/tests/test_polars.py b/skpro/datatypes/tests/test_polars.py index fcec3c92d..55b5ed573 100644 --- a/skpro/datatypes/tests/test_polars.py +++ b/skpro/datatypes/tests/test_polars.py @@ -43,7 +43,7 @@ def estimator(): def _pd_to_pl(df): - return convert(df, from_type="pd_Series_Table", to_type="polars_eager_table") + return convert(df, from_type="pd_DataFrame_Table", to_type="polars_eager_table") @pytest.fixture
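To close the loop on the polars test changes: the final patch switches the helper's from_type to "pd_DataFrame_Table", presumably because the diabetes fixtures are pandas DataFrames rather than Series. A hedged end-to-end sketch of what the updated test exercises, assuming polars and pyarrow are installed and the mtype strings above are kept:

import pandas as pd
from skpro.datatypes import check_is_mtype, convert

X = pd.DataFrame({"x1": [1.0, 2.0], "x2": [3.0, 4.0]})   # "pd_DataFrame_Table"
X_pl = convert(X, from_type="pd_DataFrame_Table", to_type="polars_eager_table")

# per the fixtures above, the conversion carries the pandas index along as an
# "__index__" column, which is dropped before comparing raw values
X_pl = X_pl.drop(["__index__"])

assert check_is_mtype(X_pl, "polars_eager_table")
assert (X.values == X_pl.to_numpy()).all()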