Refactor code and update docstring #385

Merged (3 commits, May 6, 2024)
6 changes: 3 additions & 3 deletions pypots/base.py
@@ -349,7 +349,7 @@ def fit(
train_set :
The dataset for model training, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for training, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
@@ -358,7 +358,7 @@ def fit(
val_set :
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
@@ -383,7 +383,7 @@ def predict(
test_set :
The dataset for model validating, should be a dictionary including keys as 'X',
or a path string locating a data file supported by PyPOTS (e.g. h5 file).
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
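Every docstring touched by this diff describes the same input contract: a dict whose 'X' is shaped [n_samples, n_steps, n_features] (NaNs allowed for missing values) and whose 'y', when present, is [n_samples]. A minimal NumPy sketch of assembling such a dict (the sizes below are arbitrary, not from the PR):

```python
import numpy as np

# Hypothetical sizes; PyPOTS only requires the 3-D layout, not these values.
n_samples, n_steps, n_features = 32, 24, 5

rng = np.random.default_rng(0)
X = rng.standard_normal((n_samples, n_steps, n_features))
X[rng.random(X.shape) < 0.1] = np.nan  # ~10% missing values, allowed by the API

y = rng.integers(0, 2, size=n_samples)  # one classification label per sample

train_set = {"X": X, "y": y}
assert train_set["X"].shape == (n_samples, n_steps, n_features)
assert train_set["y"].shape == (n_samples,)
```

Passing a path string instead of a dict defers loading to PyPOTS, which expects the same keys inside the file (e.g. an h5 file).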
30 changes: 13 additions & 17 deletions pypots/classification/base.py
@@ -81,7 +81,7 @@ def fit(
train_set :
The dataset for model training, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for training, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
@@ -90,7 +90,7 @@ def fit(
val_set :
The dataset for model validating, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
@@ -113,15 +113,15 @@ def predict(
@abstractmethod
def classify(
self,
-X: Union[dict, str],
+test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.

Parameters
----------
-X :
-The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+test_set :
+The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.

file_type :
Expand All @@ -132,8 +132,7 @@ def classify(
array-like, shape [n_samples],
Classification results of the given samples.
"""
-# this is for old API compatibility, will be removed in the future.
-# Please implement predict() instead.

raise NotImplementedError


@@ -395,7 +394,7 @@ def fit(
train_set :
The dataset for model training, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for training, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
@@ -404,7 +403,7 @@ def fit(
val_set :
The dataset for model validating, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
@@ -427,19 +426,17 @@ def predict(
@abstractmethod
def classify(
self,
-X: Union[dict, str],
+test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.

-Warnings
---------
-The method classify is deprecated. Please use `predict()` instead.


Parameters
----------
-X :
-The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+test_set :
+The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.

file_type :
@@ -450,6 +447,5 @@ def classify(
array-like, shape [n_samples],
Classification results of the given samples.
"""
-# this is for old API compatibility, will be removed in the future.
-# Please implement predict() instead.

raise NotImplementedError
2 changes: 1 addition & 1 deletion pypots/classification/brits/data.py
@@ -20,7 +20,7 @@ class DatasetForBRITS(DatasetForBRITS_Imputation):
data :
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for input, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
17 changes: 5 additions & 12 deletions pypots/classification/brits/model.py
@@ -17,7 +17,6 @@
from ..base import BaseNNClassifier
from ...optim.adam import Adam
from ...optim.base import Optimizer
-from ...utils.logging import logger


class BRITS(BaseNNClassifier):
@@ -257,19 +256,15 @@ def predict(

def classify(
self,
-X: Union[dict, str],
+test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.

-Warnings
---------
-The method classify is deprecated. Please use `predict()` instead.

Parameters
----------
-X :
-The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+test_set :
+The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.

file_type :
@@ -280,8 +275,6 @@ def classify(
array-like, shape [n_samples],
Classification results of the given samples.
"""
-logger.warning(
-"🚨DeprecationWarning: The method classify is deprecated. Please use `predict` instead."
-)
-result_dict = self.predict(X, file_type=file_type)
+result_dict = self.predict(test_set, file_type=file_type)
return result_dict["classification"]
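The recurring model-level change here (and in the GRU-D and Raindrop diffs below) keeps classify() only as a thin alias for predict(), with the deprecation warning dropped. The pattern, with the model internals stubbed out (TinyClassifier is illustrative, not part of PyPOTS):

```python
import numpy as np

class TinyClassifier:
    """Stub mirroring the refactored classify()/predict() split."""

    def predict(self, test_set, file_type="hdf5"):
        # A real model runs inference here; we emit a fixed label per sample.
        n_samples = len(test_set["X"])
        return {"classification": np.zeros(n_samples, dtype=int)}

    def classify(self, test_set, file_type="hdf5"):
        # Thin wrapper, exactly the shape of the refactored method above.
        result_dict = self.predict(test_set, file_type=file_type)
        return result_dict["classification"]

labels = TinyClassifier().classify({"X": np.zeros((4, 24, 5))})
assert labels.shape == (4,)
```

Keeping predict() as the single inference entry point means every task-specific convenience method (classify, cluster, ...) stays a one-line adapter over the same result dict.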
2 changes: 1 addition & 1 deletion pypots/classification/grud/data.py
@@ -23,7 +23,7 @@ class DatasetForGRUD(BaseDataset):
data :
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for input, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
17 changes: 5 additions & 12 deletions pypots/classification/grud/model.py
@@ -18,7 +18,6 @@
from ..base import BaseNNClassifier
from ...optim.adam import Adam
from ...optim.base import Optimizer
-from ...utils.logging import logger


class GRUD(BaseNNClassifier):
@@ -234,19 +233,15 @@ def predict(

def classify(
self,
-X: Union[dict, str],
+test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.

-Warnings
---------
-The method classify is deprecated. Please use `predict()` instead.

Parameters
----------
-X :
-The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+test_set :
+The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.

file_type :
@@ -257,8 +252,6 @@ def classify(
array-like, shape [n_samples],
Classification results of the given samples.
"""
-logger.warning(
-"🚨DeprecationWarning: The method classify is deprecated. Please use `predict` instead."
-)
-result_dict = self.predict(X, file_type=file_type)
+result_dict = self.predict(test_set, file_type=file_type)
return result_dict["classification"]
2 changes: 1 addition & 1 deletion pypots/classification/raindrop/data.py
@@ -19,7 +19,7 @@ class DatasetForRaindrop(DatasetForGRUD):
data :
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for input, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
17 changes: 5 additions & 12 deletions pypots/classification/raindrop/model.py
@@ -19,7 +19,6 @@
from ...classification.base import BaseNNClassifier
from ...optim.adam import Adam
from ...optim.base import Optimizer
-from ...utils.logging import logger


class Raindrop(BaseNNClassifier):
@@ -279,19 +278,15 @@ def predict(

def classify(
self,
-X: Union[dict, str],
+test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Classify the input data with the trained model.

-Warnings
---------
-The method classify is deprecated. Please use `predict()` instead.

Parameters
----------
-X :
-The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+test_set :
+The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.

file_type :
@@ -302,8 +297,6 @@ def classify(
array-like, shape [n_samples],
Classification results of the given samples.
"""
-logger.warning(
-"🚨DeprecationWarning: The method classify is deprecated. Please use `predict` instead."
-)
-result_dict = self.predict(X, file_type=file_type)
+result_dict = self.predict(test_set, file_type=file_type)
return result_dict["classification"]
7 changes: 7 additions & 0 deletions pypots/classification/template/model.py
@@ -88,3 +88,10 @@ def predict(
file_type: str = "hdf5",
) -> dict:
raise NotImplementedError

+def classify(
+self,
+test_set: Union[dict, str],
+file_type: str = "hdf5",
+) -> dict:
+raise NotImplementedError
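The template gains a classify() stub with the same signature, to be overridden by concrete models. A self-contained sketch of that abstract-method contract (the class names here are hypothetical, not PyPOTS code):

```python
from abc import ABC, abstractmethod
from typing import Union

class BaseClassifierSketch(ABC):
    @abstractmethod
    def classify(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        raise NotImplementedError

class MyClassifier(BaseClassifierSketch):
    def classify(self, test_set, file_type="hdf5"):
        # Concrete models replace the stub with real inference.
        return {"classification": [0] * len(test_set["X"])}

try:
    BaseClassifierSketch()  # abstract: instantiation must fail
    raise AssertionError("abstract class should not instantiate")
except TypeError:
    pass

result = MyClassifier().classify({"X": [[1.0], [2.0]]})
assert result["classification"] == [0, 0]
```

Declaring the stub on the template keeps every subclass's public surface identical, so callers can rely on classify() existing across models.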
30 changes: 13 additions & 17 deletions pypots/clustering/base.py
@@ -81,15 +81,15 @@ def fit(
train_set :
The dataset for model training, should be a dictionary including the key 'X',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for training, can contain missing values.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include the key 'X'.

val_set :
The dataset for model validating, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
@@ -112,15 +112,15 @@ def predict(
@abstractmethod
def cluster(
self,
-X: Union[dict, str],
+test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Cluster the input with the trained model.

Parameters
----------
-X :
-The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+test_set :
+The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.

file_type :
@@ -131,8 +131,7 @@ def cluster(
array-like,
Clustering results.
"""
-# this is for old API compatibility, will be removed in the future.
-# Please implement predict() instead.

raise NotImplementedError


@@ -388,15 +387,15 @@ def fit(
train_set :
The dataset for model training, should be a dictionary including the key 'X',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for training, can contain missing values.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include the key 'X'.

val_set :
The dataset for model validating, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
@@ -419,19 +418,17 @@ def predict(
@abstractmethod
def cluster(
self,
-X: Union[dict, str],
+test_set: Union[dict, str],
file_type: str = "hdf5",
) -> np.ndarray:
"""Cluster the input with the trained model.

-Warnings
---------
-The method cluster is deprecated. Please use `predict()` instead.


Parameters
----------
-X :
-The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+test_set :
+The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
n_features], or a path string locating a data file, e.g. h5 file.

file_type :
@@ -442,6 +439,5 @@ def cluster(
array-like,
Clustering results.
"""
-# this is for old API compatibility, will be removed in the future.
-# Please implement predict() instead.

raise NotImplementedError
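Both cluster() and classify() accept either an in-memory dict or a path string. A sketch of how that Union[dict, str] dispatch can work (resolve_input is illustrative, not the PyPOTS implementation, which reads h5 files; np.load on an .npz archive stands in here):

```python
import numpy as np

def resolve_input(data):
    """Resolve the dict-or-path convention used throughout these docstrings."""
    if isinstance(data, dict):
        # Already in memory: use as-is.
        return data
    if isinstance(data, str):
        # Treat a string as a path to a key-value file and load it into a dict.
        archive = np.load(data)
        return {key: archive[key] for key in archive.files}
    raise TypeError(f"expected dict or path string, got {type(data).__name__}")

dataset = {"X": np.zeros((2, 3, 4))}
assert resolve_input(dataset) is dataset
```

Normalizing the input once at the boundary lets fit(), predict(), and the convenience wrappers share a single code path regardless of how the caller supplied the data.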
2 changes: 1 addition & 1 deletion pypots/clustering/crli/data.py
@@ -19,7 +19,7 @@ class DatasetForCRLI(BaseDataset):
data :
The dataset for model input, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
-If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
which is time-series data for input, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains