-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] refactor datatypes
mtypes - checkers, converters
#392
Merged
Merged
Changes from all commits
Commits
Show all changes
45 commits
Select commit
Hold shift + click to select a range
7dcf9ac
Create retrieval.py
fkiraly f95d757
Merge branch 'main' into mtype-refactor
fkiraly 3657d94
base cls
fkiraly b35a1fa
base table
fkiraly 00a57b9
docstr
fkiraly df00cfd
Merge branch 'main' into mtype-refactor
fkiraly f0f80fe
Revert "[MNT] increase `numpy` bound to `numpy < 2.1`, `numpy 2` comp…
fkiraly cb3ae05
Reapply "[MNT] increase `numpy` bound to `numpy < 2.1`, `numpy 2` com…
fkiraly 449b579
Revert "Reapply "[MNT] increase `numpy` bound to `numpy < 2.1`, `nump…
fkiraly e16e87a
Reapply "Reapply "[MNT] increase `numpy` bound to `numpy < 2.1`, `num…
fkiraly 77dff4c
Merge branch 'main' into mtype-refactor
fkiraly aa68bb8
continuing refactor
fkiraly 64acb3e
Revert "continuing refactor"
fkiraly 5a05ca2
Reapply "continuing refactor"
fkiraly 3b6cdb9
linting
fkiraly 65a89fb
remove imports
fkiraly b1f1abd
docstrings
fkiraly d6903a0
fix tags
fkiraly dbf08ef
simplify retrieval
fkiraly 9d41a7f
Update _check.py
fkiraly adcad09
Update _check.py
fkiraly 649f1ae
convert
fkiraly d8e3130
Update _check.py
fkiraly 0acb1e0
use _get_key in type
fkiraly 2047e6b
converter base
fkiraly 3df7025
base clses
fkiraly 7f5bf90
cache checkers
fkiraly 6c24d08
converters - experimental
fkiraly 47fbb6c
Update _convert.py
fkiraly a2b91cd
linting
fkiraly 97234e9
Update _convert.py
fkiraly 304819e
Update _base.py
fkiraly 0faf0b8
bugfixes
fkiraly 710fb79
tests
fkiraly 05fc210
handle soft deps
fkiraly 84410d9
handle softdeps
fkiraly 7e60251
manage soft deps
fkiraly 33df864
bugfix
fkiraly 39e6223
Update _check.py
fkiraly 84b42a7
Merge branch 'main' into mtype-refactor
fkiraly 7a5a028
Merge branch 'main' into mtype-refactor
fkiraly 698774b
revert test_polars changes
fkiraly 2a9b2b5
Update test_polars.py
fkiraly f317b82
Update test_polars.py
fkiraly f60aee2
Update test_polars.py
fkiraly File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Base module for datatypes.""" | ||
|
||
from skpro.datatypes._base._base import BaseConverter, BaseDatatype | ||
|
||
__all__ = ["BaseConverter", "BaseDatatype"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,357 @@ | ||
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) | ||
"""Base class for data types.""" | ||
|
||
__author__ = ["fkiraly"] | ||
|
||
from skpro.base import BaseObject | ||
from skpro.datatypes._common import _ret | ||
from skpro.utils.deep_equals import deep_equals | ||
|
||
|
||
class BaseDatatype(BaseObject): | ||
"""Base class for data types. | ||
|
||
This class is the base class for all data types in sktime. | ||
""" | ||
|
||
_tags = { | ||
"object_type": "datatype", | ||
"scitype": None, | ||
"name": None, # any string | ||
"name_python": None, # lower_snake_case | ||
"name_aliases": [], | ||
"python_version": None, | ||
"python_dependencies": None, | ||
} | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
||
# call defaults to check | ||
def __call__(self, obj, return_metadata=False, var_name="obj"): | ||
"""Check if obj is of this data type. | ||
|
||
Parameters | ||
---------- | ||
obj : any | ||
Object to check. | ||
return_metadata : bool, optional (default=False) | ||
Whether to return metadata. | ||
var_name : str, optional (default="obj") | ||
Name of the variable to check, for use in error messages. | ||
|
||
Returns | ||
------- | ||
valid : bool | ||
Whether obj is of this data type. | ||
msg : str, only returned if return_metadata is True. | ||
Error message if obj is not of this data type. | ||
metadata : instance of self only returned if return_metadata is True. | ||
Metadata dictionary. | ||
""" | ||
return self._check(obj=obj, return_metadata=return_metadata, var_name=var_name) | ||
|
||
def check(self, obj, return_metadata=False, var_name="obj"): | ||
"""Check if obj is of this data type. | ||
|
||
If self has parameters set, the check will in addition | ||
check whether metadata of obj is equal to self's parameters. | ||
In this case, ``return_metadata`` will always include the | ||
metadata fields required to check the parameters. | ||
|
||
Parameters | ||
---------- | ||
obj : any | ||
Object to check. | ||
return_metadata : bool, optional (default=False) | ||
Whether to return metadata. | ||
var_name : str, optional (default="obj") | ||
Name of the variable to check, for use in error messages. | ||
|
||
Returns | ||
------- | ||
valid : bool | ||
Whether obj is of this data type. | ||
msg : str, only returned if return_metadata is True. | ||
Error message if obj is not of this data type. | ||
metadata : instance of self only returned if return_metadata is True. | ||
Metadata dictionary. | ||
""" | ||
self_params = self.get_params() | ||
|
||
need_check = [k for k in self_params if self_params[k] is not None] | ||
self_dict = {k: self_params[k] for k in need_check} | ||
|
||
return_metadata_orig = return_metadata | ||
|
||
# update return_metadata to retrieve any self_params | ||
# return_metadata_bool has updated condition | ||
if not len(need_check) == 0: | ||
if isinstance(return_metadata, bool): | ||
if not return_metadata: | ||
return_metadata = need_check | ||
return_metadata_bool = True | ||
else: | ||
return_metadata = set(return_metadata).union(need_check) | ||
return_metadata = list(return_metadata) | ||
return_metadata_bool = True | ||
elif isinstance(return_metadata, bool): | ||
return_metadata_bool = return_metadata | ||
else: | ||
return_metadata_bool = True | ||
|
||
# call inner _check | ||
check_res = self._check( | ||
obj=obj, return_metadata=return_metadata, var_name=var_name | ||
) | ||
|
||
if return_metadata_bool: | ||
valid = check_res[0] | ||
msg = check_res[1] | ||
metadata = check_res[2] | ||
else: | ||
valid = check_res | ||
msg = "" | ||
|
||
if not valid: | ||
return _ret(False, msg, None, return_metadata_orig) | ||
|
||
# now we know the check is valid, but we need to compare fields | ||
metadata_sub = {k: metadata[k] for k in self_dict} | ||
eqs, msg = deep_equals(self_dict, metadata_sub, return_msg=True) | ||
if not eqs: | ||
msg = f"metadata of type unequal, {msg}" | ||
return _ret(False, msg, None, return_metadata_orig) | ||
|
||
self_type = type(self)(**metadata) | ||
return _ret(True, "", self_type, return_metadata_orig) | ||
|
||
def _check(self, obj, return_metadata=False, var_name="obj"): | ||
"""Check if obj is of this data type. | ||
|
||
Parameters | ||
---------- | ||
obj : any | ||
Object to check. | ||
return_metadata : bool, optional (default=False) | ||
Whether to return metadata. | ||
var_name : str, optional (default="obj") | ||
Name of the variable to check, for use in error messages. | ||
|
||
Returns | ||
------- | ||
valid : bool | ||
Whether obj is of this data type. | ||
msg : str, only returned if return_metadata is True. | ||
Error message if obj is not of this data type. | ||
metadata : dict, only returned if return_metadata is True. | ||
Metadata dictionary. | ||
""" | ||
raise NotImplementedError | ||
|
||
def __getitem__(self, key): | ||
"""Get attribute by key. | ||
|
||
Parameters | ||
---------- | ||
key : str | ||
Attribute name. | ||
|
||
Returns | ||
------- | ||
value : any | ||
Attribute value. | ||
""" | ||
return getattr(self, key) | ||
|
||
def get(self, key, default=None): | ||
"""Get attribute by key. | ||
|
||
Parameters | ||
---------- | ||
key : str | ||
Attribute name. | ||
default : any, optional (default=None) | ||
Default value if attribute does not exist. | ||
|
||
Returns | ||
------- | ||
value : any | ||
Attribute value. | ||
""" | ||
return getattr(self, key, default) | ||
|
||
def _get_key(self): | ||
"""Get unique dictionary key corresponding to self. | ||
|
||
Private function, used in collecting a dictionary of checks. | ||
""" | ||
mtype = self.get_class_tag("name") | ||
scitype = self.get_class_tag("scitype") | ||
return (mtype, scitype) | ||
|
||
|
||
class BaseConverter(BaseObject): | ||
"""Base class for data type converters. | ||
|
||
This class is the base class for all data type converters in sktime. | ||
""" | ||
|
||
_tags = { | ||
"object_type": "converter", | ||
"mtype_from": None, # type to convert from - BaseDatatype class or str | ||
"mtype_to": None, # type to convert to - BaseDatatype class or str | ||
"multiple_conversions": False, # whether converter encodes multiple conversions | ||
"python_version": None, | ||
"python_dependencies": None, | ||
} | ||
|
||
def __init__(self, mtype_from=None, mtype_to=None): | ||
self.mtype_from = mtype_from | ||
self.mtype_to = mtype_to | ||
super().__init__() | ||
|
||
if mtype_from is not None: | ||
self.set_tags(**{"mtype_from": mtype_from}) | ||
if mtype_to is not None: | ||
self.set_tags(**{"mtype_to": mtype_to}) | ||
|
||
mtype_from = self.get_tag("mtype_from") | ||
mtype_to = self.get_tag("mtype_to") | ||
|
||
if mtype_from is None: | ||
raise ValueError( | ||
f"Error in instantiating {self.__class__.__name__}: " | ||
"mtype_from and mtype_to must be set if the class has no defaults. " | ||
"For valid pairs of defaults, use get_conversions." | ||
) | ||
if mtype_to is None: | ||
raise ValueError( | ||
f"Error in instantiating {self.__class__.__name__}: " | ||
"mtype_to must be set in constructor, as the class has no defaults. " | ||
"For valid pairs of defaults, use get_conversions." | ||
) | ||
if (mtype_from, mtype_to) not in self.__class__.get_conversions(): | ||
raise ValueError( | ||
f"Error in instantiating {self.__class__.__name__}: " | ||
"mtype_from and mtype_to must be a valid pair of defaults. " | ||
"For valid pairs of defaults, use get_conversions." | ||
) | ||
|
||
# call defaults to convert | ||
def __call__(self, obj, store=None): | ||
"""Convert obj to another machine type. | ||
|
||
Parameters | ||
---------- | ||
obj : any | ||
Object to convert. | ||
store : dict, optional (default=None) | ||
Reference of storage for lossy conversions. | ||
|
||
Returns | ||
------- | ||
converted_obj : any | ||
Object obj converted to another machine type. | ||
""" | ||
return self.convert(obj=obj, store=store) | ||
|
||
def convert(self, obj, store=None): | ||
"""Convert obj to another machine type. | ||
|
||
Parameters | ||
---------- | ||
obj : any | ||
Object to convert. | ||
store : dict, optional (default=None) | ||
Reference of storage for lossy conversions. | ||
""" | ||
return self._convert(obj, store) | ||
|
||
def _convert(self, obj, store=None): | ||
"""Convert obj to another machine type. | ||
|
||
Parameters | ||
---------- | ||
obj : any | ||
Object to convert. | ||
store : dict, optional (default=None) | ||
Reference of storage for lossy conversions. | ||
""" | ||
raise NotImplementedError | ||
|
||
@classmethod | ||
def get_conversions(cls): | ||
"""Get all conversions. | ||
|
||
Returns | ||
------- | ||
list of tuples (BaseDatatype subclass, BaseDatatype subclass) | ||
List of all conversions in this class. | ||
""" | ||
cls_from = cls.get_class_tag("mtype_from") | ||
cls_to = cls.get_class_tag("mtype_to") | ||
|
||
if cls_from is not None and cls_to is not None: | ||
return [(cls_from, cls_to)] | ||
# if multiple conversions are encoded, this should be overridden | ||
raise NotImplementedError | ||
|
||
def _get_cls_from_to(self): | ||
"""Get classes from and to. | ||
|
||
Returns | ||
------- | ||
cls_from : BaseDatatype subclass | ||
Class to convert from. | ||
cls_to : BaseDatatype subclass | ||
Class to convert to. | ||
""" | ||
cls_from = self.get_tag("mtype_from") | ||
cls_to = self.get_tag("mtype_to") | ||
|
||
cls_from = _coerce_str_to_cls(cls_from) | ||
cls_to = _coerce_str_to_cls(cls_to) | ||
|
||
return cls_from, cls_to | ||
|
||
def _get_key(self): | ||
"""Get unique dictionary key corresponding to self. | ||
|
||
Private function, used in collecting a dictionary of checks. | ||
""" | ||
cls_from, cls_to = self._get_cls_from_to() | ||
|
||
mtype_from = cls_from.get_class_tag("name") | ||
mtype_to = cls_to.get_class_tag("name") | ||
scitype = cls_to.get_class_tag("scitype") | ||
return (mtype_from, mtype_to, scitype) | ||
|
||
|
||
def _coerce_str_to_cls(cls_or_str): | ||
"""Get class from string. | ||
|
||
Parameters | ||
---------- | ||
cls_or_str : str or class | ||
Class or string. If string, assumed to be a unique mtype string from | ||
one of the BaseDatatype subclasses. | ||
|
||
Returns | ||
------- | ||
cls : cls_or_str, if was class; otherwise, class corresponding to string. | ||
""" | ||
if not isinstance(cls_or_str, str): | ||
return cls_or_str | ||
|
||
# otherwise, we use the string to get the class from the check dict | ||
# perhaps it is nicer to transfer this to a registry later. | ||
from skpro.datatypes._check import get_check_dict | ||
|
||
cd = get_check_dict(soft_deps="all") | ||
cls = [cd[k].__class__ for k in cd if k[0] == cls_or_str] | ||
if len(cls) > 1: | ||
raise ValueError(f"Error in converting string to class: {cls_or_str}") | ||
elif len(cls) < 1: | ||
return None | ||
return cls[0] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a quick question - what is this code used for exactly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's a feature, where you instantiate the type with arguments. This defines a subtype, e.g., "numpy 2D without nans", and then checks against the subtype, rather than just the overall type.