Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] refactor datatypes mtypes - checkers, converters #392

Merged
merged 45 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
7dcf9ac
Create retrieval.py
fkiraly Jun 16, 2024
f95d757
Merge branch 'main' into mtype-refactor
fkiraly Jun 23, 2024
3657d94
base cls
fkiraly Jun 23, 2024
b35a1fa
base table
fkiraly Jun 23, 2024
00a57b9
docstr
fkiraly Jun 23, 2024
df00cfd
Merge branch 'main' into mtype-refactor
fkiraly Jun 23, 2024
f0f80fe
Revert "[MNT] increase `numpy` bound to `numpy < 2.1`, `numpy 2` comp…
fkiraly Jun 23, 2024
cb3ae05
Reapply "[MNT] increase `numpy` bound to `numpy < 2.1`, `numpy 2` com…
fkiraly Jun 23, 2024
449b579
Revert "Reapply "[MNT] increase `numpy` bound to `numpy < 2.1`, `nump…
fkiraly Jun 23, 2024
e16e87a
Reapply "Reapply "[MNT] increase `numpy` bound to `numpy < 2.1`, `num…
fkiraly Jun 23, 2024
77dff4c
Merge branch 'main' into mtype-refactor
fkiraly Jul 14, 2024
aa68bb8
continuing refactor
fkiraly Jul 14, 2024
64acb3e
Revert "continuing refactor"
fkiraly Jul 14, 2024
5a05ca2
Reapply "continuing refactor"
fkiraly Jul 14, 2024
3b6cdb9
linting
fkiraly Jul 14, 2024
65a89fb
remove imports
fkiraly Jul 14, 2024
b1f1abd
docstrings
fkiraly Jul 14, 2024
d6903a0
fix tags
fkiraly Jul 14, 2024
dbf08ef
simplify retrieval
fkiraly Jul 14, 2024
9d41a7f
Update _check.py
fkiraly Jul 14, 2024
adcad09
Update _check.py
fkiraly Jul 14, 2024
649f1ae
convert
fkiraly Jul 14, 2024
d8e3130
Update _check.py
fkiraly Jul 22, 2024
0acb1e0
use _get_key in type
fkiraly Jul 24, 2024
2047e6b
converter base
fkiraly Jul 24, 2024
3df7025
base clses
fkiraly Jul 25, 2024
7f5bf90
cache checkers
fkiraly Jul 25, 2024
6c24d08
converters - experimental
fkiraly Jul 25, 2024
47fbb6c
Update _convert.py
fkiraly Jul 25, 2024
a2b91cd
linting
fkiraly Jul 25, 2024
97234e9
Update _convert.py
fkiraly Jul 25, 2024
304819e
Update _base.py
fkiraly Jul 25, 2024
0faf0b8
bugfixes
fkiraly Jul 25, 2024
710fb79
tests
fkiraly Jul 25, 2024
05fc210
handle soft deps
fkiraly Jul 25, 2024
84410d9
handle softdeps
fkiraly Jul 25, 2024
7e60251
manage soft deps
fkiraly Jul 25, 2024
33df864
bugfix
fkiraly Jul 25, 2024
39e6223
Update _check.py
fkiraly Aug 12, 2024
84b42a7
Merge branch 'main' into mtype-refactor
fkiraly Aug 12, 2024
7a5a028
Merge branch 'main' into mtype-refactor
fkiraly Sep 7, 2024
698774b
revert test_polars changes
fkiraly Sep 7, 2024
2a9b2b5
Update test_polars.py
fkiraly Sep 7, 2024
f317b82
Update test_polars.py
fkiraly Sep 7, 2024
f60aee2
Update test_polars.py
fkiraly Sep 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions skpro/datatypes/_base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Base module for datatypes."""

from skpro.datatypes._base._base import BaseConverter, BaseDatatype

__all__ = ["BaseConverter", "BaseDatatype"]
357 changes: 357 additions & 0 deletions skpro/datatypes/_base/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,357 @@
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""Base class for data types."""

__author__ = ["fkiraly"]

from skpro.base import BaseObject
from skpro.datatypes._common import _ret
from skpro.utils.deep_equals import deep_equals


class BaseDatatype(BaseObject):
"""Base class for data types.

This class is the base class for all data types in sktime.
"""

_tags = {
"object_type": "datatype",
"scitype": None,
"name": None, # any string
"name_python": None, # lower_snake_case
"name_aliases": [],
"python_version": None,
"python_dependencies": None,
}

def __init__(self):
super().__init__()

# call defaults to check
def __call__(self, obj, return_metadata=False, var_name="obj"):
"""Check if obj is of this data type.

Parameters
----------
obj : any
Object to check.
return_metadata : bool, optional (default=False)
Whether to return metadata.
var_name : str, optional (default="obj")
Name of the variable to check, for use in error messages.

Returns
-------
valid : bool
Whether obj is of this data type.
msg : str, only returned if return_metadata is True.
Error message if obj is not of this data type.
metadata : instance of self only returned if return_metadata is True.
Metadata dictionary.
"""
return self._check(obj=obj, return_metadata=return_metadata, var_name=var_name)

def check(self, obj, return_metadata=False, var_name="obj"):
"""Check if obj is of this data type.

If self has parameters set, the check will in addition
check whether metadata of obj is equal to self's parameters.
In this case, ``return_metadata`` will always include the
metadata fields required to check the parameters.

Parameters
----------
obj : any
Object to check.
return_metadata : bool, optional (default=False)
Whether to return metadata.
var_name : str, optional (default="obj")
Name of the variable to check, for use in error messages.

Returns
-------
valid : bool
Whether obj is of this data type.
msg : str, only returned if return_metadata is True.
Error message if obj is not of this data type.
metadata : instance of self only returned if return_metadata is True.
Metadata dictionary.
"""
self_params = self.get_params()

need_check = [k for k in self_params if self_params[k] is not None]
self_dict = {k: self_params[k] for k in need_check}
Comment on lines +80 to +83
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a quick question - what is this code used for exactly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a feature, where you instantiate the type with arguments. This defines a subtype, e.g., "numpy 2D without nans", and then checks against the subtype, rather than just the overall type.


return_metadata_orig = return_metadata

# update return_metadata to retrieve any self_params
# return_metadata_bool has updated condition
if not len(need_check) == 0:
if isinstance(return_metadata, bool):
if not return_metadata:
return_metadata = need_check
return_metadata_bool = True
else:
return_metadata = set(return_metadata).union(need_check)
return_metadata = list(return_metadata)
return_metadata_bool = True
elif isinstance(return_metadata, bool):
return_metadata_bool = return_metadata
else:
return_metadata_bool = True

# call inner _check
check_res = self._check(
obj=obj, return_metadata=return_metadata, var_name=var_name
)

if return_metadata_bool:
valid = check_res[0]
msg = check_res[1]
metadata = check_res[2]
else:
valid = check_res
msg = ""

if not valid:
return _ret(False, msg, None, return_metadata_orig)

# now we know the check is valid, but we need to compare fields
metadata_sub = {k: metadata[k] for k in self_dict}
eqs, msg = deep_equals(self_dict, metadata_sub, return_msg=True)
if not eqs:
msg = f"metadata of type unequal, {msg}"
return _ret(False, msg, None, return_metadata_orig)

self_type = type(self)(**metadata)
return _ret(True, "", self_type, return_metadata_orig)

def _check(self, obj, return_metadata=False, var_name="obj"):
"""Check if obj is of this data type.

Parameters
----------
obj : any
Object to check.
return_metadata : bool, optional (default=False)
Whether to return metadata.
var_name : str, optional (default="obj")
Name of the variable to check, for use in error messages.

Returns
-------
valid : bool
Whether obj is of this data type.
msg : str, only returned if return_metadata is True.
Error message if obj is not of this data type.
metadata : dict, only returned if return_metadata is True.
Metadata dictionary.
"""
raise NotImplementedError

def __getitem__(self, key):
"""Get attribute by key.

Parameters
----------
key : str
Attribute name.

Returns
-------
value : any
Attribute value.
"""
return getattr(self, key)

def get(self, key, default=None):
"""Get attribute by key.

Parameters
----------
key : str
Attribute name.
default : any, optional (default=None)
Default value if attribute does not exist.

Returns
-------
value : any
Attribute value.
"""
return getattr(self, key, default)

def _get_key(self):
"""Get unique dictionary key corresponding to self.

Private function, used in collecting a dictionary of checks.
"""
mtype = self.get_class_tag("name")
scitype = self.get_class_tag("scitype")
return (mtype, scitype)


class BaseConverter(BaseObject):
"""Base class for data type converters.

This class is the base class for all data type converters in sktime.
"""

_tags = {
"object_type": "converter",
"mtype_from": None, # type to convert from - BaseDatatype class or str
"mtype_to": None, # type to convert to - BaseDatatype class or str
"multiple_conversions": False, # whether converter encodes multiple conversions
"python_version": None,
"python_dependencies": None,
}

def __init__(self, mtype_from=None, mtype_to=None):
self.mtype_from = mtype_from
self.mtype_to = mtype_to
super().__init__()

if mtype_from is not None:
self.set_tags(**{"mtype_from": mtype_from})
if mtype_to is not None:
self.set_tags(**{"mtype_to": mtype_to})

mtype_from = self.get_tag("mtype_from")
mtype_to = self.get_tag("mtype_to")

if mtype_from is None:
raise ValueError(
f"Error in instantiating {self.__class__.__name__}: "
"mtype_from and mtype_to must be set if the class has no defaults. "
"For valid pairs of defaults, use get_conversions."
)
if mtype_to is None:
raise ValueError(
f"Error in instantiating {self.__class__.__name__}: "
"mtype_to must be set in constructor, as the class has no defaults. "
"For valid pairs of defaults, use get_conversions."
)
if (mtype_from, mtype_to) not in self.__class__.get_conversions():
raise ValueError(
f"Error in instantiating {self.__class__.__name__}: "
"mtype_from and mtype_to must be a valid pair of defaults. "
"For valid pairs of defaults, use get_conversions."
)

# call defaults to convert
def __call__(self, obj, store=None):
"""Convert obj to another machine type.

Parameters
----------
obj : any
Object to convert.
store : dict, optional (default=None)
Reference of storage for lossy conversions.

Returns
-------
converted_obj : any
Object obj converted to another machine type.
"""
return self.convert(obj=obj, store=store)

def convert(self, obj, store=None):
"""Convert obj to another machine type.

Parameters
----------
obj : any
Object to convert.
store : dict, optional (default=None)
Reference of storage for lossy conversions.
"""
return self._convert(obj, store)

def _convert(self, obj, store=None):
"""Convert obj to another machine type.

Parameters
----------
obj : any
Object to convert.
store : dict, optional (default=None)
Reference of storage for lossy conversions.
"""
raise NotImplementedError

@classmethod
def get_conversions(cls):
"""Get all conversions.

Returns
-------
list of tuples (BaseDatatype subclass, BaseDatatype subclass)
List of all conversions in this class.
"""
cls_from = cls.get_class_tag("mtype_from")
cls_to = cls.get_class_tag("mtype_to")

if cls_from is not None and cls_to is not None:
return [(cls_from, cls_to)]
# if multiple conversions are encoded, this should be overridden
raise NotImplementedError

def _get_cls_from_to(self):
"""Get classes from and to.

Returns
-------
cls_from : BaseDatatype subclass
Class to convert from.
cls_to : BaseDatatype subclass
Class to convert to.
"""
cls_from = self.get_tag("mtype_from")
cls_to = self.get_tag("mtype_to")

cls_from = _coerce_str_to_cls(cls_from)
cls_to = _coerce_str_to_cls(cls_to)

return cls_from, cls_to

def _get_key(self):
"""Get unique dictionary key corresponding to self.

Private function, used in collecting a dictionary of checks.
"""
cls_from, cls_to = self._get_cls_from_to()

mtype_from = cls_from.get_class_tag("name")
mtype_to = cls_to.get_class_tag("name")
scitype = cls_to.get_class_tag("scitype")
return (mtype_from, mtype_to, scitype)


def _coerce_str_to_cls(cls_or_str):
"""Get class from string.

Parameters
----------
cls_or_str : str or class
Class or string. If string, assumed to be a unique mtype string from
one of the BaseDatatype subclasses.

Returns
-------
cls : cls_or_str, if was class; otherwise, class corresponding to string.
"""
if not isinstance(cls_or_str, str):
return cls_or_str

# otherwise, we use the string to get the class from the check dict
# perhaps it is nicer to transfer this to a registry later.
from skpro.datatypes._check import get_check_dict

cd = get_check_dict(soft_deps="all")
cls = [cd[k].__class__ for k in cd if k[0] == cls_or_str]
if len(cls) > 1:
raise ValueError(f"Error in converting string to class: {cls_or_str}")
elif len(cls) < 1:
return None
return cls[0]
Loading
Loading