diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/PowerTransformer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/PowerTransformer.py
deleted file mode 100644
index cb3eb2b54..000000000
--- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/PowerTransformer.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from typing import Any, Dict, Optional
-
-from ConfigSpace.configuration_space import ConfigurationSpace
-from ConfigSpace.hyperparameters import (
-    CategoricalHyperparameter,
-)
-
-import numpy as np
-
-import sklearn.preprocessing
-from sklearn.base import BaseEstimator
-
-from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
-from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.feature_preprocessing. \
-    base_feature_preprocessor import autoPyTorchFeaturePreprocessingComponent
-from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter
-
-
-class PowerTransformer(autoPyTorchFeaturePreprocessingComponent):
-    def __init__(self, standardize: bool = True,
-                 random_state: Optional[np.random.RandomState] = None):
-        self.standardize = standardize
-
-        super().__init__(random_state=random_state)
-
-    def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
-        self.preprocessor['numerical'] = sklearn.preprocessing.PowerTransformer(method="yeo-johnson",
-                                                                                standardize=self.standardize,
-                                                                                copy=False)
-        return self
-
-    @staticmethod
-    def get_properties(dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None) -> Dict[str, Any]:
-        return {'shortname': 'PowerTransformer',
-                'name': 'Power Transformer',
-                'handles_sparse': True}
-
-    @staticmethod
-    def get_hyperparameter_search_space(
-        dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None,
-        standardize: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='standardize',
-                                                                           value_range=(True, False),
-                                                                           default_value=True,
-                                                                           ),
-    ) -> ConfigurationSpace:
-        cs = ConfigurationSpace()
-        add_hyperparameter(cs, standardize, CategoricalHyperparameter)
-
-        return cs
diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/__init__.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/__init__.py
index a3937a626..68ed0678f 100644
--- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/__init__.py
+++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/__init__.py
@@ -72,7 +72,6 @@ def get_hyperparameter_search_space(self,
             'RandomKitchenSinks',
             'Nystroem',
             'PolynomialFeatures',
-            'PowerTransformer',
             'TruncatedSVD',
         ]
         for default_ in defaults:
diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/PowerTransformer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/PowerTransformer.py
new file mode 100644
index 000000000..7dd2502f9
--- /dev/null
+++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/PowerTransformer.py
@@ -0,0 +1,38 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+
+from sklearn.preprocessing import PowerTransformer as SklearnPowerTransformer
+
+from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
+from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.base_scaler import BaseScaler
+
+
+class PowerTransformer(BaseScaler):
+    """
+    Map data to as close to a Gaussian distribution as possible
+    in order to reduce variance and minimize skewness.
+
+    Uses `yeo-johnson` power transform method. Also, data is normalised
+    to zero mean and unit variance.
+    """
+    def __init__(self,
+                 random_state: Optional[np.random.RandomState] = None):
+        super().__init__()
+        self.random_state = random_state
+
+    def fit(self, X: Dict[str, Any], y: Any = None) -> BaseScaler:
+
+        self.check_requirements(X, y)
+
+        self.preprocessor['numerical'] = SklearnPowerTransformer(method='yeo-johnson', copy=False)
+        return self
+
+    @staticmethod
+    def get_properties(dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None
+                       ) -> Dict[str, Union[str, bool]]:
+        return {
+            'shortname': 'PowerTransformer',
+            'name': 'PowerTransformer',
+            'handles_sparse': False
+        }
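
Review note: a minimal standalone sketch of the behaviour the new scaling component wraps, using plain scikit-learn on an invented toy array (this is not the autoPyTorch API). With scikit-learn's default standardize=True, the yeo-johnson output has approximately zero mean and unit variance per column, which is what the docstring above promises:

    import numpy as np
    from sklearn.preprocessing import PowerTransformer

    X = np.array([[1.0, 10.0], [2.0, 100.0], [3.0, 1000.0]])  # right-skewed toy data
    pt = PowerTransformer(method='yeo-johnson')  # standardize=True is sklearn's default
    Xt = pt.fit_transform(X)
    print(Xt.mean(axis=0))  # ~[0, 0]
    print(Xt.std(axis=0))   # ~[1, 1]
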
diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/QuantileTransformer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/QuantileTransformer.py
new file mode 100644
index 000000000..cc0b4fa7a
--- /dev/null
+++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/QuantileTransformer.py
@@ -0,0 +1,73 @@
+from typing import Any, Dict, Optional, Union
+
+from ConfigSpace.configuration_space import ConfigurationSpace
+from ConfigSpace.hyperparameters import (
+    CategoricalHyperparameter,
+    UniformIntegerHyperparameter
+)
+
+import numpy as np
+
+from sklearn.preprocessing import QuantileTransformer as SklearnQuantileTransformer
+
+from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
+from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.base_scaler import BaseScaler
+from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter
+
+
+class QuantileTransformer(BaseScaler):
+    """
+    Transform the features to follow a uniform or a normal distribution
+    using quantiles information.
+
+    For more details of each attribute, see:
+    https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.QuantileTransformer.html
+    """
+    def __init__(
+        self,
+        n_quantiles: int = 1000,
+        output_distribution: str = "normal",  # Literal["normal", "uniform"]
+        random_state: Optional[np.random.RandomState] = None
+    ):
+        super().__init__()
+        self.random_state = random_state
+        self.n_quantiles = n_quantiles
+        self.output_distribution = output_distribution
+
+    def fit(self, X: Dict[str, Any], y: Any = None) -> BaseScaler:
+
+        self.check_requirements(X, y)
+
+        self.preprocessor['numerical'] = SklearnQuantileTransformer(n_quantiles=self.n_quantiles,
+                                                                    output_distribution=self.output_distribution,
+                                                                    copy=False)
+        return self
+
+    @staticmethod
+    def get_hyperparameter_search_space(
+        dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None,
+        n_quantiles: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="n_quantiles",
+                                                                           value_range=(10, 2000),
+                                                                           default_value=1000,
+                                                                           ),
+        output_distribution: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="output_distribution",
+                                                                                   value_range=("uniform", "normal"),
+                                                                                   default_value="normal",
+                                                                                   )
+    ) -> ConfigurationSpace:
+        cs = ConfigurationSpace()
+
+        # TODO parametrize like the Random Forest as n_quantiles = n_features^param
+        add_hyperparameter(cs, n_quantiles, UniformIntegerHyperparameter)
+        add_hyperparameter(cs, output_distribution, CategoricalHyperparameter)
+
+        return cs
+
+    @staticmethod
+    def get_properties(dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None
+                       ) -> Dict[str, Union[str, bool]]:
+        return {
+            'shortname': 'QuantileTransformer',
+            'name': 'QuantileTransformer',
+            'handles_sparse': False
+        }
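
Review note: a sketch of the wrapped estimator on made-up data (plain scikit-learn, not the component above). scikit-learn caps n_quantiles at the number of training samples and emits a warning, which matters for the default of 1000 on small inputs such as the unit tests below:

    import numpy as np
    from sklearn.preprocessing import QuantileTransformer

    X = np.array([[1.0], [4.0], [14.0]])
    qt = QuantileTransformer(n_quantiles=1000, output_distribution='uniform')
    Xt = qt.fit_transform(X)  # warns and reduces n_quantiles to 3 here
    print(Xt.ravel())  # [0.0, 0.5, 1.0], the empirical CDF of the training column
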
diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/RobustScaler.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/RobustScaler.py
new file mode 100644
index 000000000..2c59d77c2
--- /dev/null
+++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/RobustScaler.py
@@ -0,0 +1,73 @@
+from typing import Any, Dict, Optional, Union
+
+from ConfigSpace.configuration_space import ConfigurationSpace
+from ConfigSpace.hyperparameters import (
+    UniformFloatHyperparameter,
+)
+
+import numpy as np
+
+from sklearn.preprocessing import RobustScaler as SklearnRobustScaler
+
+from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
+from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.base_scaler import BaseScaler
+from autoPyTorch.utils.common import FitRequirement, HyperparameterSearchSpace, add_hyperparameter
+
+
+class RobustScaler(BaseScaler):
+    """
+    Remove the median and scale features according to the quantile_range to make
+    the features robust to outliers.
+
+    For more details of the preprocessor, see:
+    https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html
+    """
+    def __init__(
+        self,
+        q_min: float = 0.25,
+        q_max: float = 0.75,
+        random_state: Optional[np.random.RandomState] = None
+    ):
+        super().__init__()
+        self.add_fit_requirements([
+            FitRequirement('issparse', (bool,), user_defined=True, dataset_property=True)])
+        self.random_state = random_state
+        self.q_min = q_min
+        self.q_max = q_max
+
+    def fit(self, X: Dict[str, Any], y: Any = None) -> BaseScaler:
+
+        self.check_requirements(X, y)
+        with_centering = bool(not X['dataset_properties']['issparse'])
+
+        self.preprocessor['numerical'] = SklearnRobustScaler(quantile_range=(self.q_min, self.q_max),
+                                                             with_centering=with_centering,
+                                                             copy=False)
+
+        return self
+
+    @staticmethod
+    def get_hyperparameter_search_space(
+        dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None,
+        q_min: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="q_min",
+                                                                     value_range=(0.001, 0.3),
+                                                                     default_value=0.25),
+        q_max: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="q_max",
+                                                                     value_range=(0.7, 0.999),
+                                                                     default_value=0.75)
+    ) -> ConfigurationSpace:
+        cs = ConfigurationSpace()
+
+        add_hyperparameter(cs, q_min, UniformFloatHyperparameter)
+        add_hyperparameter(cs, q_max, UniformFloatHyperparameter)
+
+        return cs
+
+    @staticmethod
+    def get_properties(dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None
+                       ) -> Dict[str, Union[str, bool]]:
+        return {
+            'shortname': 'RobustScaler',
+            'name': 'RobustScaler',
+            'handles_sparse': True
+        }
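
Review note: fit() above disables centering for sparse input because subtracting the median would densify a sparse matrix (scikit-learn raises a ValueError when with_centering=True is combined with sparse data). Also worth flagging: scikit-learn's quantile_range is expressed in percentiles, with a default of (25.0, 75.0), so the component's defaults (q_min, q_max) = (0.25, 0.75) select a very narrow band on that scale. A sketch on invented data:

    import numpy as np
    from sklearn.preprocessing import RobustScaler

    X = np.array([[1.0], [4.0], [14.0]])
    rs = RobustScaler(quantile_range=(0.25, 0.75))  # 0.25th to 0.75th percentile
    print(rs.fit_transform(X))  # (x - median) / (P0.75 - P0.25) -> [[-100.], [0.], [333.33...]]
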
diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/__init__.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/__init__.py
index 082b17cb9..d4d3ffeb5 100644
--- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/__init__.py
+++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/scaling/__init__.py
@@ -66,9 +66,21 @@ def get_hyperparameter_search_space(self,
             raise ValueError("no scalers found, please add a scaler")

         if default is None:
-            defaults = ['StandardScaler', 'Normalizer', 'MinMaxScaler', 'NoScaler']
+            defaults = [
+                'StandardScaler',
+                'Normalizer',
+                'MinMaxScaler',
+                'PowerTransformer',
+                'QuantileTransformer',
+                'RobustScaler',
+                'NoScaler'
+            ]
             for default_ in defaults:
                 if default_ in available_scalers:
+                    if include is not None and default_ not in include:
+                        continue
+                    if exclude is not None and default_ in exclude:
+                        continue
                     default = default_
                     break

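
Review note: the hunk above does two things: it extends the ranked default list with the three new scalers, and it skips default candidates that the caller's include/exclude filters rule out. A minimal self-contained sketch of that selection rule (pick_default and its inputs are hypothetical names, not autoPyTorch API):

    defaults = ['StandardScaler', 'Normalizer', 'MinMaxScaler', 'PowerTransformer',
                'QuantileTransformer', 'RobustScaler', 'NoScaler']

    def pick_default(available_scalers, include=None, exclude=None):
        # mirrors the loop in get_hyperparameter_search_space
        for name in defaults:
            if name in available_scalers:
                if include is not None and name not in include:
                    continue
                if exclude is not None and name in exclude:
                    continue
                return name
        return None

    # excluding StandardScaler now falls through to the next available candidate
    assert pick_default({'StandardScaler', 'RobustScaler'}, exclude=['StandardScaler']) == 'RobustScaler'
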
diff --git a/test/test_pipeline/components/preprocessing/test_feature_preprocessor.py b/test/test_pipeline/components/preprocessing/test_feature_preprocessor.py
index 99fad6b1f..31f41a876 100644
--- a/test/test_pipeline/components/preprocessing/test_feature_preprocessor.py
+++ b/test/test_pipeline/components/preprocessing/test_feature_preprocessor.py
@@ -20,7 +20,7 @@ def random_state():
     return 11


-@pytest.fixture(params=['TruncatedSVD', 'PolynomialFeatures', 'PowerTransformer',
+@pytest.fixture(params=['TruncatedSVD', 'PolynomialFeatures',
                         'Nystroem', 'KernelPCA', 'RandomKitchenSinks'])
 def preprocessor(request):
     return request.param
diff --git a/test/test_pipeline/components/preprocessing/test_scalers.py b/test/test_pipeline/components/preprocessing/test_scalers.py
index 94ba0f2dc..7cbc12b07 100644
--- a/test/test_pipeline/components/preprocessing/test_scalers.py
+++ b/test/test_pipeline/components/preprocessing/test_scalers.py
@@ -9,6 +9,11 @@
 from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.MinMaxScaler import MinMaxScaler
 from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.NoScaler import NoScaler
 from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.Normalizer import Normalizer
+from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.PowerTransformer import \
+    PowerTransformer
+from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.QuantileTransformer import \
+    QuantileTransformer
+from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.RobustScaler import RobustScaler
 from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.StandardScaler import StandardScaler


@@ -239,3 +244,163 @@ def test_none_scaler(self):
         self.assertIsInstance(X['scaler'], dict)
         self.assertIsNone(X['scaler']['categorical'])
         self.assertIsNone(X['scaler']['numerical'])
+
+
+def test_power_transformer():
+    data = np.array([[1, 2, 3],
+                     [7, 8, 9],
+                     [4, 5, 6],
+                     [11, 12, 13],
+                     [17, 18, 19],
+                     [14, 15, 16]])
+    train_indices = np.array([0, 2, 5])
+    test_indices = np.array([1, 4, 3])
+    categorical_columns = list()
+    numerical_columns = [0, 1, 2]
+    dataset_properties = {'categorical_columns': categorical_columns,
+                          'numerical_columns': numerical_columns,
+                          'issparse': False}
+    X = {
+        'X_train': data[train_indices],
+        'dataset_properties': dataset_properties
+    }
+    scaler_component = PowerTransformer()
+
+    scaler_component = scaler_component.fit(X)
+    X = scaler_component.transform(X)
+    scaler = X['scaler']['numerical']
+
+    # check if the fit dictionary X is modified as expected
+    assert isinstance(X['scaler'], dict)
+    assert isinstance(scaler, BaseEstimator)
+    assert X['scaler']['categorical'] is None
+
+    # make column transformer with returned encoder to fit on data
+    column_transformer = make_column_transformer((scaler, X['dataset_properties']['numerical_columns']),
+                                                 remainder='passthrough')
+    column_transformer = column_transformer.fit(X['X_train'])
+    transformed = column_transformer.transform(data[test_indices])
+
+    assert_allclose(transformed, np.array([[0.531648, 0.522782, 0.515394],
+                                           [1.435794, 1.451064, 1.461685],
+                                           [0.993609, 1.001055, 1.005734]]), rtol=1e-06)
+
+
+def test_robust_scaler():
+    data = np.array([[1, 2, 3],
+                     [7, 8, 9],
+                     [4, 5, 6],
+                     [11, 12, 13],
+                     [17, 18, 19],
+                     [14, 15, 16]])
+    train_indices = np.array([0, 2, 5])
+    test_indices = np.array([1, 4, 3])
+    categorical_columns = list()
+    numerical_columns = [0, 1, 2]
+    dataset_properties = {'categorical_columns': categorical_columns,
+                          'numerical_columns': numerical_columns,
+                          'issparse': False}
+    X = {
+        'X_train': data[train_indices],
+        'dataset_properties': dataset_properties
+    }
+    scaler_component = RobustScaler()
+
+    scaler_component = scaler_component.fit(X)
+    X = scaler_component.transform(X)
+    scaler = X['scaler']['numerical']
+
+    # check if the fit dictionary X is modified as expected
+    assert isinstance(X['scaler'], dict)
+    assert isinstance(scaler, BaseEstimator)
+    assert X['scaler']['categorical'] is None
+
+    # make column transformer with returned encoder to fit on data
+    column_transformer = make_column_transformer((scaler, X['dataset_properties']['numerical_columns']),
+                                                 remainder='passthrough')
+    column_transformer = column_transformer.fit(X['X_train'])
+    transformed = column_transformer.transform(data[test_indices])
+
+    assert_allclose(transformed, np.array([[100, 100, 100],
+                                           [433.33333333, 433.33333333, 433.33333333],
+                                           [233.33333333, 233.33333333, 233.33333333]]))
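
Review note: the expected values in test_robust_scaler can be reproduced by hand. Each training column is [1, 4, 14] (rows 0, 2, 5 of data), the center is the median 4, and with quantile_range=(0.25, 0.75) read as percentiles the scale is the interpolated 0.75th minus 0.25th percentile, 1.045 - 1.015 = 0.03. A quick verification sketch:

    import numpy as np

    col = np.array([1.0, 4.0, 14.0])  # one training column
    scale = np.percentile(col, 0.75) - np.percentile(col, 0.25)  # 0.03
    print((np.array([7.0, 17.0, 11.0]) - np.median(col)) / scale)
    # -> [100.0, 433.33..., 233.33...], matching the assertion above
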
+
+
+class TestQuantileTransformer():
+    def test_quantile_transformer_uniform(self):
+        data = np.array([[1, 2, 3],
+                         [7, 8, 9],
+                         [4, 5, 6],
+                         [11, 12, 13],
+                         [17, 18, 19],
+                         [14, 15, 16]])
+        train_indices = np.array([0, 2, 5])
+        test_indices = np.array([1, 4, 3])
+        categorical_columns = list()
+        numerical_columns = [0, 1, 2]
+        dataset_properties = {'categorical_columns': categorical_columns,
+                              'numerical_columns': numerical_columns,
+                              'issparse': False}
+        X = {
+            'X_train': data[train_indices],
+            'dataset_properties': dataset_properties
+        }
+        scaler_component = QuantileTransformer(output_distribution='uniform')
+
+        scaler_component = scaler_component.fit(X)
+        X = scaler_component.transform(X)
+        scaler = X['scaler']['numerical']
+
+        # check if the fit dictionary X is modified as expected
+        assert isinstance(X['scaler'], dict)
+        assert isinstance(scaler, BaseEstimator)
+        assert X['scaler']['categorical'] is None
+
+        # make column transformer with returned encoder to fit on data
+        column_transformer = make_column_transformer((scaler, X['dataset_properties']['numerical_columns']),
+                                                     remainder='passthrough')
+        column_transformer = column_transformer.fit(X['X_train'])
+        transformed = column_transformer.transform(data[test_indices])
+
+        assert_allclose(transformed, np.array([[0.65, 0.65, 0.65],
+                                               [1, 1, 1],
+                                               [0.85, 0.85, 0.85]]), rtol=1e-06)
+
+    def test_quantile_transformer_normal(self):
+        data = np.array([[1, 2, 3],
+                         [7, 8, 9],
+                         [4, 5, 6],
+                         [11, 12, 13],
+                         [17, 18, 19],
+                         [14, 15, 16]])
+        train_indices = np.array([0, 2, 5])
+        test_indices = np.array([1, 4, 3])
+        categorical_columns = list()
+        numerical_columns = [0, 1, 2]
+        dataset_properties = {'categorical_columns': categorical_columns,
+                              'numerical_columns': numerical_columns,
+                              'issparse': False}
+        X = {
+            'X_train': data[train_indices],
+            'dataset_properties': dataset_properties
+        }
+        scaler_component = QuantileTransformer(output_distribution='normal')
+
+        scaler_component = scaler_component.fit(X)
+        X = scaler_component.transform(X)
+        scaler = X['scaler']['numerical']
+
+        # check if the fit dictionary X is modified as expected
+        assert isinstance(X['scaler'], dict)
+        assert isinstance(scaler, BaseEstimator)
+        assert X['scaler']['categorical'] is None
+
+        # make column transformer with returned encoder to fit on data
+        column_transformer = make_column_transformer((scaler, X['dataset_properties']['numerical_columns']),
+                                                     remainder='passthrough')
+        column_transformer = column_transformer.fit(X['X_train'])
+        transformed = column_transformer.transform(data[test_indices])
+
+        assert_allclose(transformed, np.array([[0.38532, 0.38532, 0.38532],
+                                               [5.199338, 5.199338, 5.199338],
+                                               [1.036433, 1.036433, 1.036433]]), rtol=1e-05)
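
Review note: the quantile expectations can be reproduced by hand as well. With only three training samples per column, n_quantiles collapses to 3, so the learned quantile positions [0, 0.5, 1] sit at values [1, 4, 14] (column 0); test inputs interpolate on that empirical CDF, and the normal case then applies the Gaussian inverse CDF, clipped near the 1e-7 bounds as scikit-learn does to keep the output finite. A verification sketch:

    import numpy as np
    from scipy.stats import norm

    references, quantiles = np.array([0.0, 0.5, 1.0]), np.array([1.0, 4.0, 14.0])
    uniform = np.interp([7.0, 17.0, 11.0], quantiles, references)
    print(uniform)  # -> [0.65, 1.0, 0.85], the 'uniform' expectations
    print(norm.ppf(np.clip(uniform, 1e-7, 1 - 1e-7)))
    # -> [0.38532, 5.199338, 1.036433], the 'normal' expectations
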