diff --git a/.gitignore b/.gitignore index 6709d3188..a11615f27 100644 --- a/.gitignore +++ b/.gitignore @@ -138,3 +138,7 @@ dask-worker-space/ # Test output tmp/ .tmp_evaluation + +# Private file +grep.py +memo.txt diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index 16f498e43..3cab42093 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -33,7 +33,7 @@ STRING_TO_TASK_TYPES, ) from autoPyTorch.datasets.base_dataset import BaseDataset -from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes +from autoPyTorch.datasets.train_val_split import CrossValTypes, HoldOutTypes from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager from autoPyTorch.ensemble.ensemble_selection import EnsembleSelection from autoPyTorch.ensemble.singlebest_ensemble import SingleBest @@ -175,8 +175,8 @@ def __init__( # By default try to use the TCP logging port or get a new port self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT - # Store the resampling strategy from the dataset, to load models as needed - self.resampling_strategy = None # type: Optional[Union[CrossValTypes, HoldoutValTypes]] + # Store the splitting type from the dataset, to load models as needed + self.splitting_type = None # type: Optional[Union[CrossValTypes, HoldOutTypes]] self.stop_logging_server = None # type: Optional[multiprocessing.synchronize.Event] @@ -398,21 +398,21 @@ def _close_dask_client(self) -> None: self._is_dask_client_internally_created = False del self._is_dask_client_internally_created - def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] + def _load_models(self, splitting_type: Optional[Union[CrossValTypes, HoldOutTypes]] ) -> bool: """ Loads the models saved in the temporary directory during the smac run and the final ensemble created Args: - resampling_strategy (Union[CrossValTypes, HoldoutValTypes]): resampling strategy used to split the data + 
splitting_type (Union[CrossValTypes, HoldOutTypes]): splitting type used to split the data and to validate the performance of a candidate pipeline Returns: None """ - if resampling_strategy is None: - raise ValueError("Resampling strategy is needed to determine what models to load") + if splitting_type is None: + raise ValueError("Splitting type is needed to determine what models to load") self.ensemble_ = self._backend.load_ensemble(self.seed) # If no ensemble is loaded, try to get the best performing model @@ -422,10 +422,10 @@ def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, Holdou if self.ensemble_: identifiers = self.ensemble_.get_selected_model_identifiers() self.models_ = self._backend.load_models_by_identifiers(identifiers) - if isinstance(resampling_strategy, CrossValTypes): + if isinstance(splitting_type, CrossValTypes): self.cv_models_ = self._backend.load_cv_models_by_identifiers(identifiers) - if isinstance(resampling_strategy, CrossValTypes): + if isinstance(splitting_type, CrossValTypes): if len(self.cv_models_) == 0: raise ValueError('No models fitted!') @@ -705,7 +705,7 @@ def search( dataset_properties = dataset.get_dataset_properties(dataset_requirements) self._stopwatch.start_task(experiment_task_name) self.dataset_name = dataset.dataset_name - self.resampling_strategy = dataset.resampling_strategy + self.splitting_type = dataset.splitting_type self._logger = self._get_logger(self.dataset_name) self._all_supported_metrics = all_supported_metrics self._disable_file_output = disable_file_output @@ -869,7 +869,7 @@ def search( if load_models: self._logger.info("Loading models...") - self._load_models(dataset.resampling_strategy) + self._load_models(dataset.splitting_type) self._logger.info("Finished loading models...") # Clean up the logger @@ -927,7 +927,7 @@ def refit( }) X.update({**self.pipeline_options, **budget_config}) if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None: - 
self._load_models(dataset.resampling_strategy) + self._load_models(dataset.splitting_type) # Refit is not applicable when ensemble_size is set to zero. if self.ensemble_ is None: @@ -1025,7 +1025,7 @@ def predict( if self._logger is None: self._logger = self._get_logger("Predict-Logger") - if self.ensemble_ is None and not self._load_models(self.resampling_strategy): + if self.ensemble_ is None and not self._load_models(self.splitting_type): raise ValueError("No ensemble found. Either fit has not yet " "been called or no ensemble was fitted") @@ -1033,9 +1033,9 @@ def predict( assert self.ensemble_ is not None, "Load models should error out if no ensemble" self.ensemble_ = cast(Union[SingleBest, EnsembleSelection], self.ensemble_) - if isinstance(self.resampling_strategy, HoldoutValTypes): + if isinstance(self.splitting_type, HoldOutTypes): models = self.models_ - elif isinstance(self.resampling_strategy, CrossValTypes): + elif isinstance(self.splitting_type, CrossValTypes): models = self.cv_models_ all_predictions = joblib.Parallel(n_jobs=n_jobs)( diff --git a/autoPyTorch/constants.py b/autoPyTorch/constants.py index 652a546b9..f524d9b4b 100644 --- a/autoPyTorch/constants.py +++ b/autoPyTorch/constants.py @@ -1,3 +1,87 @@ +"""Constant numbers for this package + +TODO: + * Makes this file nicer + * transfer to proper locations e.g. create task directory? 
+""" + +from enum import Enum +from autoPyTorch.pipeline.image_classification import ImageClassificationPipeline +from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline +from autoPyTorch.pipeline.tabular_regression import TabularRegressionPipeline +import abc +from typing import Union + + +SupportedPipelines = Union[ImageClassificationPipeline, + TabularClassificationPipeline, + TabularRegressionPipeline] + + +class BaseTaskTypes(metaclass=abc.ABCMeta): + @abc.abstractmethod + def is_supported(self) -> bool: + raise NotImplementedError + + @abc.abstractmethod + def task_name(self) -> str: + raise NotImplementedError + + @abc.abstractmethod + def dataset_type(self) -> str: + raise NotImplementedError + + @abc.abstractmethod + def pipeline(self) -> SupportedPipelines: + raise NotImplementedError + + +class RegressionTypes(Enum, BaseTaskTypes): + tabular = TabularRegressionPipeline + image = None + time_series = None + + def is_supported(self) -> bool: + return self.value is not None + + def task_name(self) -> str: + return 'regressor' + + def dataset_type(self) -> str: + return self.name + + def pipeline(self) -> SupportedPipelines: + if not self.is_supported(): + raise ValueError(f"{self.name} is not supported pipeline.") + + return self.value + + +class ClassificationTypes(Enum, BaseTaskTypes): + tabular = TabularClassificationPipeline + image = ImageClassificationPipeline + time_series = None + + def is_supported(self) -> bool: + return self.value is not None + + def task_name(self) -> str: + return 'classifier' + + def dataset_type(self) -> str: + return self.name + + def pipeline(self) -> SupportedPipelines: + if not self.is_supported(): + raise ValueError(f"{self.name} is not supported pipeline.") + + return self.value + + +SupportedTaskTypes = (RegressionTypes, ClassificationTypes) + + +"""TODO: remove these variables TABULAR_CLASSIFICATION = 1 IMAGE_CLASSIFICATION = 2 TABULAR_REGRESSION = 3 @@ -27,6 +111,8 @@ 'image_regression': 
IMAGE_REGRESSION, 'time_series_classification': TIMESERIES_CLASSIFICATION, 'time_series_regression': TIMESERIES_REGRESSION} +""" + # Output types have been defined as in scikit-learn type_of_target # (https://scikit-learn.org/stable/modules/generated/sklearn.utils.multiclass.type_of_target.html) diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index 565ffd4f3..4d1c9f9f9 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -1,5 +1,5 @@ from abc import ABCMeta -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Callable, NamedTuple import numpy as np @@ -11,42 +11,49 @@ import torchvision -from autoPyTorch.datasets.resampling_strategy import ( - CROSS_VAL_FN, +from autoPyTorch.datasets.train_val_split import ( + CrossValFuncs, CrossValTypes, - DEFAULT_RESAMPLING_PARAMETERS, - HOLDOUT_FN, - HoldoutValTypes, - get_cross_validators, - get_holdout_validators, - is_stratified, + CrossValParameters, + HoldOutFuncs, + HoldOutTypes, + HoldOutParameters ) -from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix +from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix, BaseNamedTuple -BASE_DATASET_INPUT = Union[Tuple[np.ndarray, np.ndarray], Dataset] +BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset] +SplitFunc = Callable[[int, np.ndarray, Any], List[Tuple[np.ndarray, np.ndarray]]] def check_valid_data(data: Any) -> None: - if not (hasattr(data, '__getitem__') and hasattr(data, '__len__')): + if not all(hasattr(data, attr) for attr in ['__getitem__', '__len__']): raise ValueError( - 'The specified Data for Dataset does either not have a __getitem__ or a __len__ attribute.') + 'The specified Data for Dataset must have both __getitem__ and __len__ attribute.') -def type_check(train_tensors: BASE_DATASET_INPUT, val_tensors: Optional[BASE_DATASET_INPUT] = 
None) -> None: - for i in range(len(train_tensors)): - check_valid_data(train_tensors[i]) +def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None: + for train_tensor in train_tensors: + check_valid_data(train_tensor) if val_tensors is not None: - for i in range(len(val_tensors)): - check_valid_data(val_tensors[i]) + for val_tensor in val_tensors: + check_valid_data(val_tensor) class TransformSubset(Subset): - """ + """The title of the class description + TODO: + ```previous doc-string. (I did not understand the meaning.) Because the BaseDataset contains all the data (train/val/test), the transformations have to be applied with some directions. That is, if yielding train data, we expect to apply train transformation (which have augmentations exclusively). We achieve so by adding a train flag to the pytorch subset + ``` + + Attributes: + dataset (torch.utils.data.Dataset): The description + indices (Sequence[int]): The description + train (bool): If training or Validation. 
""" def __init__(self, dataset: Dataset, indices: Sequence[int], train: bool) -> None: self.dataset = dataset @@ -57,102 +64,123 @@ def __getitem__(self, idx: int) -> np.ndarray: return self.dataset.__getitem__(self.indices[idx], self.train) +class _DatasetCommonProperties(BaseNamedTuple, NamedTuple): + """TODO: doc-string""" + task_type: Optional[str] + output_type: str + issparse: bool + input_shape: Tuple[int] + output_shape: Tuple[int] + num_classes: Optional[int] + + class BaseDataset(Dataset, metaclass=ABCMeta): def __init__( self, - train_tensors: BASE_DATASET_INPUT, + train_tensors: BaseDatasetType, dataset_name: Optional[str] = None, - val_tensors: Optional[BASE_DATASET_INPUT] = None, - test_tensors: Optional[BASE_DATASET_INPUT] = None, - resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, - resampling_strategy_args: Optional[Dict[str, Any]] = None, + val_tensors: Optional[BaseDatasetType] = None, + test_tensors: Optional[BaseDatasetType] = None, + splitting_type: Union[str, CrossValTypes, HoldOutTypes] = HoldOutTypes.holdout_validation, + splitting_params: Optional[Dict[str, Any]] = None, shuffle: Optional[bool] = True, - seed: Optional[int] = 42, + random_state: Optional[int] = 42, train_transforms: Optional[torchvision.transforms.Compose] = None, val_transforms: Optional[torchvision.transforms.Compose] = None, ): - """ - Base class for datasets used in AutoPyTorch + """Base class for datasets used in AutoPyTorch Args: - train_tensors (A tuple of objects that have a __len__ and a __getitem__ attribute): - training data + train_tensors (A tuple of objects that have __len__ and __getitem__ attribute): + training data (A tuple of training features and labels) dataset_name (str): name of the dataset, used as experiment name. 
- val_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute): - validation data - test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute): - test data - resampling_strategy (Union[CrossValTypes, HoldoutValTypes]), - (default=HoldoutValTypes.holdout_validation): + val_tensors (An optional tuple of objects that have __len__ and __getitem__ attribute): + validation data (A tuple of validation features and labels) + test_tensors (An optional tuple of objects that have __len__ and __getitem__ attribute): + test data (A tuple of test features and labels) + splitting_type (Union[CrossValTypes, HoldOutTypes]), + (default=HoldOutTypes.holdout_validation): strategy to split the training data. - resampling_strategy_args (Optional[Dict[str, Any]]): arguments - required for the chosen resampling strategy. If None, uses - the default values provided in DEFAULT_RESAMPLING_PARAMETERS - in ```datasets/resampling_strategy.py```. + splitting_params (Optional[Dict[str, Any]]): + arguments required for the chosen splitting function. + If None, uses the default values provided in CrossValParameters and HoldOutParameters + in ```datasets/train_val_split.py```. shuffle: Whether to shuffle the data before performing splits - seed (int), (default=1): seed to be used for reproducibility. + seed (int), (default=42): seed to be used for reproducibility. 
train_transforms (Optional[torchvision.transforms.Compose]): Additional Transforms to be applied to the training data val_transforms (Optional[torchvision.transforms.Compose]): Additional Transforms to be applied to the validation/test data """ - if dataset_name is not None: - self.dataset_name = dataset_name - else: - self.dataset_name = hash_array_or_matrix(train_tensors[0]) + + self.dataset_name = dataset_name if dataset_name is not None \ + else hash_array_or_matrix(train_tensors[0]) + if not hasattr(train_tensors[0], 'shape'): type_check(train_tensors, val_tensors) - self.train_tensors = train_tensors - self.val_tensors = val_tensors - self.test_tensors = test_tensors - self.cross_validators: Dict[str, CROSS_VAL_FN] = {} - self.holdout_validators: Dict[str, HOLDOUT_FN] = {} - self.rand = np.random.RandomState(seed=seed) - self.shuffle = shuffle - self.resampling_strategy = resampling_strategy - self.resampling_strategy_args = resampling_strategy_args + + self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors + self.train_transform, self.val_transform = train_transforms, val_transforms + + self.random_state, self.shuffle = random_state, shuffle + self.rng = np.random.RandomState(seed=self.random_state) + + # Dict[str, SplitFunc] + self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes) + self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldOutTypes) + + cv_type, holdout_type = hasattr(CrossValTypes, splitting_type), hasattr(HoldOutTypes, splitting_type) + if not cv_type and not holdout_type: + raise NameError(f"Splitting type has no attribute {splitting_type}") + + self.splitting_type, = getattr(CrossValTypes, splitting_type) if cv_type \ + else getattr(HoldOutTypes, splitting_type) + self.splitting_params = splitting_params + self.convert_splitting_params_to_namedtuple() + self.splits = self.get_splits() + self.task_type: Optional[str] = None self.issparse: bool = 
issparse(self.train_tensors[0]) self.input_shape: Tuple[int] = train_tensors[0].shape[1:] self.num_classes: Optional[int] = None + if len(train_tensors) == 2 and train_tensors[1] is not None: self.output_type: str = type_of_target(self.train_tensors[1]) - self.output_shape: int = train_tensors[1].shape[1] if train_tensors[1].shape == 2 else 1 + self.output_shape: int = train_tensors[1].shape[1] if len(train_tensors[1].shape) == 2 else 1 # TODO: Look for a criteria to define small enough to preprocess self.is_small_preprocess = True - # Make sure cross validation splits are created once - self.cross_validators = get_cross_validators( - CrossValTypes.stratified_k_fold_cross_validation, - CrossValTypes.k_fold_cross_validation, - CrossValTypes.shuffle_split_cross_validation, - CrossValTypes.stratified_shuffle_split_cross_validation - ) - self.holdout_validators = get_holdout_validators( - HoldoutValTypes.holdout_validation, - HoldoutValTypes.stratified_holdout_validation - ) - self.splits = self.get_splits_from_resampling_strategy() - - # We also need to be able to transform the data, be it for pre-processing - # or for augmentation - self.train_transform = train_transforms - self.val_transform = val_transforms + def convert_splitting_params_to_namedtuple(self) -> None: + """convert splitting_params into CrossValParameters or HoldOutParameters""" + + if not isinstance(self.splitting_params, dict) and self.splitting_params is not None: + raise TypeError(f"splitting_params must be dict or None, but got {type(self.splitting_params)}") + + self.splitting_params = {} if self.splitting_params is None else self.splitting_params + + if isinstance(self.splitting_type, HoldOutTypes): + self.splitting_params = HoldOutParameters(**self.splitting_params, + random_state=self.random_state) + elif isinstance(self.splitting_type, CrossValTypes): + self.splitting_params = CrossValParameters(**self.splitting_params, + random_state=self.random_state) + else: + raise 
ValueError(f"splitting_type {self.splitting_type} is not supported.") def update_transform(self, transform: Optional[torchvision.transforms.Compose], - train: bool = True, - ) -> 'BaseDataset': + train: bool = True) -> 'BaseDataset': """ During the pipeline execution, the pipeline object might propose transformations as a product of the current pipeline configuration being tested. - This utility allows to return a self with the updated transformation, so that + This utility allows to return self with the updated transformation, so that a dataloader can yield this dataset with the desired transformations Args: - transform (torchvision.transforms.Compose): The transformations proposed - by the current pipeline - train (bool): Whether to update the train or validation transform + transform (torchvision.transforms.Compose): + The transformations proposed by the current pipeline + train (bool): + Whether to update the train or validation transform Returns: self: A copy of the update pipeline @@ -166,11 +194,11 @@ def update_transform(self, transform: Optional[torchvision.transforms.Compose], def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]: """ The base dataset uses a Subset of the data. Nevertheless, the base dataset expect - both validation and test data to be present in the same dataset, which motivated the - need to dynamically give train/test data with the __getitem__ command. + both validation and test data to be present in the same dataset, which is motivated + by the need to dynamically give train/test data with the __getitem__ command. This method yields a datapoint of the whole data (after a Subset has selected a given - item, based on the resampling strategy) and applies a train/testing transformation, if any. + item, based on the splitting functions) and applies a train/testing transformation, if any. 
Args: index (int): what element to yield from all the train/test tensors @@ -180,10 +208,8 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]: A transformed single point prediction """ - if hasattr(self.train_tensors[0], 'loc'): - X = self.train_tensors[0].iloc[[index]] - else: - X = self.train_tensors[0][index] + X = self.train_tensors[0].iloc[[index]] if hasattr(self.train_tensors[0], 'loc') \ + else self.train_tensors[0][index] if self.train_transform is not None and train: X = self.train_transform(X) @@ -191,125 +217,36 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]: X = self.val_transform(X) # In case of prediction, the targets are not provided - Y = self.train_tensors[1] - if Y is not None: - Y = Y[index] - else: - Y = None + Y = self.train_tensors[1][index] if self.train_tensors[1] is not None else None return X, Y - def __len__(self) -> int: - return self.train_tensors[0].shape[0] + def __len__(self) -> int: return self.train_tensors[0].shape[0] - def _get_indices(self) -> np.ndarray: - if self.shuffle: - indices = self.rand.permutation(len(self)) - else: - indices = np.arange(len(self)) - return indices + def _get_indices(self) -> np.ndarray: return self.rng.permutation(len(self)) if self.shuffle \ + else np.arange(len(self)) - def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]: + def get_splits(self) -> List[Tuple[List[int], List[int]]]: """ - Creates a set of splits based on a resampling strategy provided + Creates a set of splits based on a provided splitting function Returns (List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format """ - splits = [] - if isinstance(self.resampling_strategy, HoldoutValTypes): - val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get( - 'val_share', None) - if self.resampling_strategy_args is not None: - val_share = self.resampling_strategy_args.get('val_share', val_share) - 
splits.append( - self.create_holdout_val_split( - holdout_val_type=self.resampling_strategy, - val_share=val_share, - ) - ) - elif isinstance(self.resampling_strategy, CrossValTypes): - num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get( - 'num_splits', None) - if self.resampling_strategy_args is not None: - num_splits = self.resampling_strategy_args.get('num_splits', num_splits) - # Create the split if it was not created before - splits.extend( - self.create_cross_val_splits( - cross_val_type=self.resampling_strategy, - num_splits=cast(int, num_splits), - ) - ) - else: - raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}") - return splits - - def create_cross_val_splits( - self, - cross_val_type: CrossValTypes, - num_splits: int - ) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]: - """ - This function creates the cross validation split for the given task. - - It is done once per dataset to have comparable results among pipelines - Args: - cross_val_type (CrossValTypes): - num_splits (int): number of splits to be created - Returns: - (List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]): - list containing 'num_splits' splits. - """ - # Create just the split once - # This is gonna be called multiple times, because the current dataset - # is being used for multiple pipelines. That is, to be efficient with memory - # we dump the dataset to memory and read it on a need basis. 
So this function - # should be robust against multiple calls, and it does so by remembering the splits - if not isinstance(cross_val_type, CrossValTypes): - raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.') - kwargs = {} - if is_stratified(cross_val_type): - # we need additional information about the data for stratification - kwargs["stratify"] = self.train_tensors[-1] - splits = self.cross_validators[cross_val_type.name]( - num_splits, self._get_indices(), **kwargs) + stratify = self.train_tensors[-1] if self.splitting_type.is_stratified() else None + + """TODO: Think about the usage of validation data. It is not used now.""" + if isinstance(self.splitting_type, CrossValTypes): + splits = self.cross_validators[self.splitting_type.name](cv_params=self.splitting_params, + indices=self._get_indices(), + stratify=stratify) + elif isinstance(self.splitting_type, HoldOutTypes): + splits = self.holdout_validators[self.splitting_type.name](holdout_params=self.splitting_params, + indices=self._get_indices(), + stratify=stratify) return splits - def create_holdout_val_split( - self, - holdout_val_type: HoldoutValTypes, - val_share: float, - ) -> Tuple[np.ndarray, np.ndarray]: - """ - This function creates the holdout split for the given task. - - It is done once per dataset to have comparable results among pipelines - Args: - holdout_val_type (HoldoutValTypes): - val_share (float): share of the validation data - - Returns: - (Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices) - """ - if holdout_val_type is None: - raise ValueError( - '`val_share` specified, but `holdout_val_type` not specified.' 
- ) - if self.val_tensors is not None: - raise ValueError( - '`val_share` specified, but the Dataset was a given a pre-defined split at initialization already.') - if val_share < 0 or val_share > 1: - raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.") - if not isinstance(holdout_val_type, HoldoutValTypes): - raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.') - kwargs = {} - if is_stratified(holdout_val_type): - # we need additional information about the data for stratification - kwargs["stratify"] = self.train_tensors[-1] - train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs) - return train, val - def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]: """ The above split methods employ the Subset to internally subsample the whole dataset. @@ -327,12 +264,13 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]: return (TransformSubset(self, self.splits[split_id][0], train=True), TransformSubset(self, self.splits[split_id][1], train=False)) - def replace_data(self, X_train: BASE_DATASET_INPUT, X_test: Optional[BASE_DATASET_INPUT]) -> 'BaseDataset': + def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset': """ To speed up the training of small dataset, early pre-processing of the data can be made on the fly by the pipeline. In this case, we replace the original train/test tensors by this pre-processed version + TODO: X_test is None => training is True? or validation step? Args: X_train (np.ndarray): the pre-processed (imputation/encoding/...) 
train data @@ -342,8 +280,9 @@ def replace_data(self, X_train: BASE_DATASET_INPUT, X_test: Optional[BASE_DATASE self """ self.train_tensors = (X_train, self.train_tensors[1]) - if X_test is not None and self.test_tensors is not None: - self.test_tensors = (X_test, self.test_tensors[1]) + self.test_tensors = (X_test, self.test_tensors[1]) if None not in [X_test, self.test_tensors] \ + else self.test_tensors + return self def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) -> Dict[str, Any]: @@ -355,19 +294,19 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) -> contain. Returns: - + dataset_properties (Dict[str, Any]): + Dict of the dataset properties. """ dataset_properties = dict() + # SHUHEI TODO: check dataset_requirements, FitRequirement for dataset_requirement in dataset_requirements: dataset_properties[dataset_requirement.name] = getattr(self, dataset_requirement.name) # Add task type, output type and issparse to dataset properties as - # they are not a dataset requirement in the pipeline - dataset_properties.update({'task_type': self.task_type, - 'output_type': self.output_type, - 'issparse': self.issparse, - 'input_shape': self.input_shape, - 'output_shape': self.output_shape, - 'num_classes': self.num_classes, - }) + # they are not dataset requirements in the pipeline + dataset_common_properties = _DatasetCommonProperties(task_type=self.task_type, output_type=self.output_type, + issparse=self.issparse, input_shape=self.input_shape, + output_shape=self.output_shape, + num_classes=self.num_classes) + dataset_properties.update(**dataset_common_properties._asdict()) return dataset_properties diff --git a/autoPyTorch/datasets/image_dataset.py b/autoPyTorch/datasets/image_dataset.py index 4664dbaf5..828d127f4 100644 --- a/autoPyTorch/datasets/image_dataset.py +++ b/autoPyTorch/datasets/image_dataset.py @@ -21,9 +21,9 @@ TASK_TYPES_TO_STRING, ) from autoPyTorch.datasets.base_dataset import BaseDataset -from 
autoPyTorch.datasets.resampling_strategy import ( +from autoPyTorch.datasets.train_val_split import ( CrossValTypes, - HoldoutValTypes, + HoldOutTypes, ) IMAGE_DATASET_INPUT = Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]] @@ -39,13 +39,13 @@ class ImageDataset(BaseDataset): validation data test (Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]): testing data - resampling_strategy (Union[CrossValTypes, HoldoutValTypes]), - (default=HoldoutValTypes.holdout_validation): + splitting_type (Union[str, CrossValTypes, HoldOutTypes]), + (default=HoldOutTypes.holdout_validation): strategy to split the training data. - resampling_strategy_args (Optional[Dict[str, Any]]): arguments - required for the chosen resampling strategy. If None, uses - the default values provided in DEFAULT_RESAMPLING_PARAMETERS - in ```datasets/resampling_strategy.py```. + splitting_params (Optional[Dict[str, Any]]): arguments + required for the chosen splitting type. If None, uses + the default values provided in the NamedTuple + in ```datasets/train_val_split.py```. shuffle: Whether to shuffle the data before performing splits seed (int), (default=1): seed to be used for reproducibility. 
train_transforms (Optional[torchvision.transforms.Compose]): @@ -57,8 +57,8 @@ def __init__(self, train: IMAGE_DATASET_INPUT, val: Optional[IMAGE_DATASET_INPUT] = None, test: Optional[IMAGE_DATASET_INPUT] = None, - resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, - resampling_strategy_args: Optional[Dict[str, Any]] = None, + splitting_type: Union[str, CrossValTypes, HoldOutTypes] = HoldOutTypes.holdout_validation, + splitting_params: Optional[Dict[str, Any]] = None, shuffle: Optional[bool] = True, seed: Optional[int] = 42, train_transforms: Optional[torchvision.transforms.Compose] = None, @@ -73,11 +73,10 @@ def __init__(self, self.mean, self.std = _calc_mean_std(train=train) super().__init__(train_tensors=train, val_tensors=val, test_tensors=test, shuffle=shuffle, - resampling_strategy=resampling_strategy, resampling_strategy_args=resampling_strategy_args, + splitting_type=splitting_type, splitting_params=splitting_params, seed=seed, train_transforms=train_transforms, - val_transforms=val_transforms, - ) + val_transforms=val_transforms) if self.output_type is not None: if STRING_TO_OUTPUT_TYPES[self.output_type] in CLASSIFICATION_OUTPUTS: self.task_type = TASK_TYPES_TO_STRING[IMAGE_CLASSIFICATION] diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py deleted file mode 100644 index 1d0bc3077..000000000 --- a/autoPyTorch/datasets/resampling_strategy.py +++ /dev/null @@ -1,153 +0,0 @@ -from enum import IntEnum -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np - -from sklearn.model_selection import ( - KFold, - ShuffleSplit, - StratifiedKFold, - StratifiedShuffleSplit, - TimeSeriesSplit, - train_test_split -) - -from typing_extensions import Protocol - - -# Use callback protocol as workaround, since callable with function fields count 'self' as argument -class CROSS_VAL_FN(Protocol): - def __call__(self, - num_splits: int, - indices: 
np.ndarray, - stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]: - ... - - -class HOLDOUT_FN(Protocol): - def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any] - ) -> Tuple[np.ndarray, np.ndarray]: - ... - - -class CrossValTypes(IntEnum): - stratified_k_fold_cross_validation = 1 - k_fold_cross_validation = 2 - stratified_shuffle_split_cross_validation = 3 - shuffle_split_cross_validation = 4 - time_series_cross_validation = 5 - - -class HoldoutValTypes(IntEnum): - holdout_validation = 6 - stratified_holdout_validation = 7 - - -RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes] - -DEFAULT_RESAMPLING_PARAMETERS = { - HoldoutValTypes.holdout_validation: { - 'val_share': 0.33, - }, - HoldoutValTypes.stratified_holdout_validation: { - 'val_share': 0.33, - }, - CrossValTypes.k_fold_cross_validation: { - 'num_splits': 3, - }, - CrossValTypes.stratified_k_fold_cross_validation: { - 'num_splits': 3, - }, - CrossValTypes.shuffle_split_cross_validation: { - 'num_splits': 3, - }, - CrossValTypes.time_series_cross_validation: { - 'num_splits': 3, - }, -} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] - - -def get_cross_validators(*cross_val_types: CrossValTypes) -> Dict[str, CROSS_VAL_FN]: - cross_validators = {} # type: Dict[str, CROSS_VAL_FN] - for cross_val_type in cross_val_types: - cross_val_fn = globals()[cross_val_type.name] - cross_validators[cross_val_type.name] = cross_val_fn - return cross_validators - - -def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]: - holdout_validators = {} # type: Dict[str, HOLDOUT_FN] - for holdout_val_type in holdout_val_types: - holdout_val_fn = globals()[holdout_val_type.name] - holdout_validators[holdout_val_type.name] = holdout_val_fn - return holdout_validators - - -def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool: - if isinstance(val_type, str): - return val_type.lower().startswith("stratified") - 
else: - return val_type.name.lower().startswith("stratified") - - -def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]: - train, val = train_test_split(indices, test_size=val_share, shuffle=False) - return train, val - - -def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \ - -> Tuple[np.ndarray, np.ndarray]: - train, val = train_test_split(indices, test_size=val_share, shuffle=False, stratify=kwargs["stratify"]) - return train, val - - -def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - cv = ShuffleSplit(n_splits=num_splits) - splits = list(cv.split(indices)) - return splits - - -def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - cv = StratifiedShuffleSplit(n_splits=num_splits) - splits = list(cv.split(indices, kwargs["stratify"])) - return splits - - -def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - cv = StratifiedKFold(n_splits=num_splits) - splits = list(cv.split(indices, kwargs["stratify"])) - return splits - - -def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) -> List[Tuple[np.ndarray, np.ndarray]]: - """ - Standard k fold cross validation. - - :param indices: array of indices to be split - :param num_splits: number of cross validation splits - :return: list of tuples of training and validation indices - """ - cv = KFold(n_splits=num_splits) - splits = list(cv.split(indices)) - return splits - - -def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \ - -> List[Tuple[np.ndarray, np.ndarray]]: - """ - Returns train and validation indices respecting the temporal ordering of the data. 
- Dummy example: [0, 1, 2, 3] with 3 folds yields - [0] [1] - [0, 1] [2] - [0, 1, 2] [3] - - :param indices: array of indices to be split - :param num_splits: number of cross validation splits - :return: list of tuples of training and validation indices - """ - cv = TimeSeriesSplit(n_splits=num_splits) - splits = list(cv.split(indices)) - return splits diff --git a/autoPyTorch/datasets/tabular_dataset.py b/autoPyTorch/datasets/tabular_dataset.py index ab75ce3f8..877d68d65 100644 --- a/autoPyTorch/datasets/tabular_dataset.py +++ b/autoPyTorch/datasets/tabular_dataset.py @@ -20,9 +20,9 @@ TASK_TYPES_TO_STRING, ) from autoPyTorch.datasets.base_dataset import BaseDataset -from autoPyTorch.datasets.resampling_strategy import ( +from autoPyTorch.datasets.train_val_split import ( CrossValTypes, - HoldoutValTypes, + HoldOutTypes, ) @@ -40,6 +40,7 @@ def __init__(self, values: list): def __getitem__(self, item: Any) -> int: if pd.isna(item): + # SHUHEI MEMO: Why do we add the location for nan? return 0 else: return self.values[item] + 1 @@ -53,13 +54,13 @@ class TabularDataset(BaseDataset): Y (Union[np.ndarray, pd.Series]): training data targets. X_test (Optional[Union[np.ndarray, pd.DataFrame]]): input testing data. Y_test (Optional[Union[np.ndarray, pd.DataFrame]]): testing data targets - resampling_strategy (Union[CrossValTypes, HoldoutValTypes]), - (default=HoldoutValTypes.holdout_validation): + splitting_type (Union[str, CrossValTypes, HoldOutTypes]), + (default=HoldOutTypes.holdout_validation): strategy to split the training data. - resampling_strategy_args (Optional[Dict[str, Any]]): arguments - required for the chosen resampling strategy. If None, uses - the default values provided in DEFAULT_RESAMPLING_PARAMETERS - in ```datasets/resampling_strategy.py```. + splitting_params (Optional[Dict[str, Any]]): arguments + required for the chosen splitting type. If None, uses + the default values provided in the NamedTuple + in ```datasets/train_val_split.py```. 
shuffle: Whether to shuffle the data before performing splits seed (int), (default=1): seed to be used for reproducibility. train_transforms (Optional[torchvision.transforms.Compose]): @@ -75,8 +76,8 @@ def __init__(self, X: Union[np.ndarray, pd.DataFrame], Y: Union[np.ndarray, pd.Series], X_test: Optional[Union[np.ndarray, pd.DataFrame]] = None, Y_test: Optional[Union[np.ndarray, pd.DataFrame]] = None, - resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, - resampling_strategy_args: Optional[Dict[str, Any]] = None, + splitting_type: Union[str, CrossValTypes, HoldOutTypes] = HoldOutTypes.holdout_validation, + splitting_params: Optional[Dict[str, Any]] = None, shuffle: Optional[bool] = True, seed: Optional[int] = 42, train_transforms: Optional[torchvision.transforms.Compose] = None, @@ -97,6 +98,7 @@ def __init__(self, X: Union[np.ndarray, pd.DataFrame], # the below function will simply return Pandas DataFrame. Y = check_array(Y, ensure_2d=False) + # SHUHEI MEMO: num_features overlaps with input_shape in BaseDataset self.categorical_columns, self.numerical_columns, self.categories, self.num_features = \ self.infer_dataset_properties(X) @@ -114,10 +116,11 @@ def __init__(self, X: Union[np.ndarray, pd.DataFrame], Y_test, assert_single_column=True) Y_test = check_array(Y_test, ensure_2d=False) + """TODO: rename the variable names""" super().__init__(train_tensors=(X, Y), test_tensors=(X_test, Y_test), shuffle=shuffle, - resampling_strategy=resampling_strategy, - resampling_strategy_args=resampling_strategy_args, - seed=seed, train_transforms=train_transforms, + splitting_type=splitting_type, + splitting_params=splitting_params, + random_state=seed, train_transforms=train_transforms, dataset_name=dataset_name, val_transforms=val_transforms) if self.output_type is not None: @@ -150,6 +153,12 @@ def interpret_columns(self, asserting that the data contains a single column Returns: + data (pd.DataFrame): Converted data + 
data_types (List[DataTypes]): Datatypes of each column + nan_mask (Union[np.ndarray]): locations of nan in data + itovs (List[Optional[list]]): The table value in the location (col, row) + vtois (List[Optional[Value2Index]]): The index of the value in the specified column + Tuple[pd.DataFrame, List[DataTypes], Union[np.ndarray], List[Optional[list]], @@ -157,6 +166,7 @@ def interpret_columns(self, """ single_column = False if isinstance(data, np.ndarray): + # SHUHEI MEMO: When does ',' not in str(data.dtype) happen? if len(data.shape) == 1 and ',' not in str(data.dtype): single_column = True data = data[:, None] @@ -176,6 +186,7 @@ def interpret_columns(self, data_types = [] nan_mask = data.isna().to_numpy() for col_index, dtype in enumerate(data.dtypes): + # SHUHEI MEMO: dtype.kind (https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html) if dtype.kind == 'f': data_types.append(DataTypes.Float) elif dtype.kind in ('i', 'u', 'b'): @@ -188,10 +199,12 @@ def interpret_columns(self, data_types.append(DataTypes.Categorical) else: raise ValueError(f"The dtype in column {col_index} is {dtype} which is not supported.") + # SHUHEI MEMO: index to value, value to index itovs: List[Optional[List[Any]]] = [] vtois: List[Optional[Value2Index]] = [] for col_index, (_, col) in enumerate(data.iteritems()): if data_types[col_index] != DataTypes.Float: + # SHUHEI MEMO: Since we are taking a set, no replacement, but why is it fine? 
non_na_values = [v for v in set(col) if not pd.isna(v)] non_na_values.sort() itovs.append([np.nan] + non_na_values) @@ -214,6 +227,11 @@ def infer_dataset_properties(self, X: Any) \ X: input training data Returns: + categorical_columns (List[int]): The list of indices specifying categorical columns + numerical_columns (List[int]): The list of indices specifying numerical columns + categories (List[object]): The list of choices of each category + num_features (int): The number of columns or features in a given tabular data + (Tuple[List[int], List[int], List[object], int]): """ categorical_columns = [] @@ -223,7 +241,8 @@ def infer_dataset_properties(self, X: Any) \ categorical_columns.append(i) else: numerical_columns.append(i) - categories = [np.unique(X.iloc[:, a]).tolist() for a in categorical_columns] + # SHUHEI MEMO: Why don't we make it dict? + categories = [np.unique(X.iloc[:, col_idx]).tolist() for col_idx in categorical_columns] num_features = X.shape[1] return categorical_columns, numerical_columns, categories, num_features diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py index 7b0435d19..11cbec7ce 100644 --- a/autoPyTorch/datasets/time_series_dataset.py +++ b/autoPyTorch/datasets/time_series_dataset.py @@ -5,11 +5,11 @@ import torchvision.transforms from autoPyTorch.datasets.base_dataset import BaseDataset -from autoPyTorch.datasets.resampling_strategy import ( +from autoPyTorch.datasets.train_val_split import ( CrossValTypes, - HoldoutValTypes, - get_cross_validators, - get_holdout_validators + HoldOutTypes, + CrossValFuncs, + HoldOutFuncs ) TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray] # currently only numpy arrays are supported @@ -24,8 +24,8 @@ def __init__(self, n_steps: int, train: TIME_SERIES_FORECASTING_INPUT, val: Optional[TIME_SERIES_FORECASTING_INPUT] = None, - resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, - 
resampling_strategy_args: Optional[Dict[str, Any]] = None, + splitting_type: Union[str, CrossValTypes, HoldOutTypes] = HoldOutTypes.holdout_validation, + splitting_params: Optional[Dict[str, Any]] = None, shuffle: Optional[bool] = False, seed: Optional[int] = 42, train_transforms: Optional[torchvision.transforms.Compose] = None, @@ -55,13 +55,13 @@ def __init__(self, sequence_length=sequence_length, n_steps=n_steps) super().__init__(train_tensors=train, val_tensors=val, shuffle=shuffle, - resampling_strategy=resampling_strategy, resampling_strategy_args=resampling_strategy_args, + splitting_type=splitting_type, splitting_params=splitting_params, seed=seed, train_transforms=train_transforms, val_transforms=val_transforms, ) - self.cross_validators = get_cross_validators(CrossValTypes.time_series_cross_validation) - self.holdout_validators = get_holdout_validators(HoldoutValTypes.holdout_validation) + self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation) + self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldOutTypes.holdout_validation) def _check_time_series_forecasting_inputs(target_variables: Tuple[int], @@ -117,15 +117,15 @@ def __init__(self, val=val, task_type="time_series_classification") super().__init__(train_tensors=train, val_tensors=val, shuffle=True) - self.cross_validators = get_cross_validators( + self.cross_validators = CrossValFuncs.get_cross_validators( CrossValTypes.stratified_k_fold_cross_validation, CrossValTypes.k_fold_cross_validation, CrossValTypes.shuffle_split_cross_validation, CrossValTypes.stratified_shuffle_split_cross_validation ) - self.holdout_validators = get_holdout_validators( - HoldoutValTypes.holdout_validation, - HoldoutValTypes.stratified_holdout_validation + self.holdout_validators = HoldOutFuncs.get_holdout_validators( + HoldOutTypes.holdout_validation, + HoldOutTypes.stratified_holdout_validation ) @@ -135,12 +135,12 @@ def __init__(self, train: Tuple[np.ndarray, 
# NOTE(review): typing.NamedTuple does not appear to support extra base classes
# (silently ignored before Python 3.9, rejected from 3.9 on) -- confirm that the
# BaseNamedTuple mixin really provides the dict-style access on this class.
class HoldOutParameters(BaseNamedTuple, NamedTuple):
    """The parameters of hold out validators

    Attributes:
        val_ratio (float): The fraction of the given indices that is held out
            for validation. The hold out validators forward it to
            ``train_test_split`` as ``test_size`` and require 0 < val_ratio < 1.
        random_state (int or None): The random seed
    """
    val_ratio: float = 0.33
    random_state: Optional[int] = 42
class HoldOutTypes(IntEnum):
    """The type of hold out validation

    The member names match the names of the hold out split functions,
    so a member is used to select the corresponding function by name.
    """
    holdout_validation = 6
    stratified_holdout_validation = 7

    def is_stratified(self) -> bool:
        """Tell whether this hold out type needs labels to stratify the split."""
        return self is HoldOutTypes.stratified_holdout_validation
@staticmethod
def stratified_k_fold_cross_validation(indices: np.ndarray, stratify: Optional[np.ndarray],
                                       cv_params: Union[Dict[str, Any], CrossValParameters]) \
        -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Stratified k fold cross validation.

    :param indices: array of indices to be split
    :param stratify: labels used to keep the class proportions in every fold
    :param cv_params: the cross validation parameters, either a
        CrossValParameters or a dict holding its fields
    :return: list of tuples of training and validation indices
    :raises TypeError: if n_splits is not an int
    :raises ValueError: if stratify is None
    """
    if isinstance(cv_params, dict):
        cv_params = CrossValParameters(**cv_params)
    CrossValFuncs.input_warning(cv_params)
    not_implemented_stratify(stratify)

    # random_state is deliberately not forwarded: the folds are built without
    # shuffling, and scikit-learn (>= 0.24) raises a ValueError when a
    # random_state is given while shuffle is False. Older releases ignored the
    # seed in that case, so the produced folds are unchanged.
    cv = StratifiedKFold(n_splits=cv_params.n_splits)
    return list(cv.split(indices, stratify))
class HoldOutFuncs():
    """Hold out split functions.

    Every split function is a static method named after a member of
    HoldOutTypes, so that get_holdout_validators can look it up by name.
    Each of them returns a list with a single (train, validation) tuple.
    """

    @staticmethod
    def input_warning(holdout_params: HoldOutParameters):
        """Raise a ValueError unless 0 < val_ratio < 1."""
        if not 0 < holdout_params.val_ratio < 1:
            raise ValueError(f"val_ratio must be in (0, 1), but got {holdout_params.val_ratio}.")

    @staticmethod
    def holdout_validation(indices: np.ndarray, stratify: Optional[np.ndarray],
                           holdout_params: Union[Dict[str, Any], HoldOutParameters]) \
            -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Split the indices in order, without shuffling.

        :param indices: array of indices to be split
        :param stratify: unused, kept so that all split functions share a signature
        :param holdout_params: the hold out parameters, either a
            HoldOutParameters or a dict holding its fields
        :return: list with one tuple of training and validation indices
        :raises ValueError: if val_ratio is not in (0, 1)
        """
        # The signature accepts a dict, so convert it before reading attributes.
        if isinstance(holdout_params, dict):
            holdout_params = HoldOutParameters(**holdout_params)
        HoldOutFuncs.input_warning(holdout_params)

        # NOTE(review): with shuffle=False the random_state presumably has no
        # effect on the split -- confirm against the scikit-learn documentation.
        train, val = train_test_split(indices, test_size=holdout_params.val_ratio,
                                      shuffle=False, random_state=holdout_params.random_state)
        return [(train, val)]

    @staticmethod
    def stratified_holdout_validation(indices: np.ndarray, stratify: Optional[np.ndarray],
                                      holdout_params: Union[Dict[str, Any], HoldOutParameters]) \
            -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Shuffle and split the indices while stratifying on the given labels.

        :param indices: array of indices to be split
        :param stratify: labels used to stratify the split; must not be None
        :param holdout_params: the hold out parameters, either a
            HoldOutParameters or a dict holding its fields
        :return: list with one tuple of training and validation indices
        :raises ValueError: if val_ratio is not in (0, 1) or stratify is None
        """
        if isinstance(holdout_params, dict):
            holdout_params = HoldOutParameters(**holdout_params)
        HoldOutFuncs.input_warning(holdout_params)
        not_implemented_stratify(stratify)

        train, val = train_test_split(indices, test_size=holdout_params.val_ratio, shuffle=True,
                                      stratify=stratify, random_state=holdout_params.random_state)
        return [(train, val)]

    @classmethod
    def get_holdout_validators(cls, *holdout_val_types: HoldOutTypes) -> Dict[str, SplitFunc]:
        """
        Collect the split functions of the requested hold out types.

        :param holdout_val_types: the hold out types to look up
        :return: dict mapping the name of each type to its split function
        """
        holdout_validators = {
            holdout_val_type.name: getattr(cls, holdout_val_type.name)
            for holdout_val_type in holdout_val_types
        }
        return holdout_validators
a/autoPyTorch/evaluation/tae.py b/autoPyTorch/evaluation/tae.py index 770625c53..554b01daf 100644 --- a/autoPyTorch/evaluation/tae.py +++ b/autoPyTorch/evaluation/tae.py @@ -163,8 +163,8 @@ def __init__( else: self._get_test_loss = False - self.resampling_strategy = dm.resampling_strategy - self.resampling_strategy_args = dm.resampling_strategy_args + self.splitting_type = dm.splitting_type + self.splitting_params = dm.splitting_params self.search_space_updates = search_space_updates @@ -382,7 +382,7 @@ def run( if ( info is not None - and self.resampling_strategy in ['holdout-iterative-fit', 'cv-iterative-fit'] + and self.splitting_type in ['holdout-iterative-fit', 'cv-iterative-fit'] and status != StatusType.CRASHED ): learning_curve = extract_learning_curve(info) diff --git a/autoPyTorch/optimizer/smbo.py b/autoPyTorch/optimizer/smbo.py index 9a464a1bb..f52d30392 100644 --- a/autoPyTorch/optimizer/smbo.py +++ b/autoPyTorch/optimizer/smbo.py @@ -17,10 +17,11 @@ from smac.utils.io.traj_logging import TrajEntry from autoPyTorch.datasets.base_dataset import BaseDataset -from autoPyTorch.datasets.resampling_strategy import ( +from autoPyTorch.datasets.train_val_split import ( CrossValTypes, - DEFAULT_RESAMPLING_PARAMETERS, - HoldoutValTypes, + CrossValParameters, + HoldOutTypes, + HoldOutParameters ) from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash @@ -78,7 +79,9 @@ def get_smac_object( class AutoMLSMBO(object): - + """TODO + Why do we need splitting_type and splitting_params here as arguments? 
+ What happen if we put different type from BaseDataset?""" def __init__(self, config_space: ConfigSpace.ConfigurationSpace, dataset_name: str, @@ -93,8 +96,8 @@ def __init__(self, pipeline_config: typing.Dict[str, typing.Any], start_num_run: int = 1, seed: int = 1, - resampling_strategy: typing.Union[HoldoutValTypes, CrossValTypes] = HoldoutValTypes.holdout_validation, - resampling_strategy_args: typing.Optional[typing.Dict[str, typing.Any]] = None, + splitting_type: typing.Union[HoldOutTypes, CrossValTypes] = HoldOutTypes.holdout_validation, + splitting_params: typing.Optional[typing.Dict[str, typing.Any]] = None, include: typing.Optional[typing.Dict[str, typing.Any]] = None, exclude: typing.Optional[typing.Dict[str, typing.Any]] = None, disable_file_output: typing.List = [], @@ -112,7 +115,7 @@ def __init__(self, tasks in Dask. Args: - config_space (ConfigSpace.ConfigurationSpac): + config_space (ConfigSpace.ConfigurationSpace): The configuration space of the whole process dataset_name (str): The name of the dataset, used to identify the current job @@ -136,10 +139,10 @@ def __init__(self, The ID index to start runs seed (int): To make the run deterministic - resampling_strategy (str): + splitting_type (str): What strategy to use for performance validation - resampling_strategy_args (typing.Optional[typing.Dict[str, typing.Any]]): - Arguments to the resampling strategy -- like number of folds + splitting_params (typing.Optional[typing.Dict[str, typing.Any]]): + Arguments to the splitting type -- like number of folds include (typing.Optional[typing.Dict[str, typing.Any]] = None): Optimal Configuration space modifiers exclude (typing.Optional[typing.Dict[str, typing.Any]] = None): @@ -172,10 +175,15 @@ def __init__(self, self.dask_client = dask_client # Evaluation - self.resampling_strategy = resampling_strategy - if resampling_strategy_args is None: - resampling_strategy_args = DEFAULT_RESAMPLING_PARAMETERS[resampling_strategy] - self.resampling_strategy_args = 
class BaseDict(dict):
    """The extension of dict

    This class allows to call value (self[key]) by self.key.
    The main intention is to make NamedTuple mutable.

    Items and attributes are kept in sync by the methods defined here:
    setting or deleting one of them sets or deletes the other. Item lookup
    is the plain dict lookup, so names of dict methods (e.g. "keys") are
    not keys unless they were stored explicitly. Keys that are not strings
    are stored as items only.

    Example:
        class NewDict(BaseDict):
            def __init__(self, a: int = 1, b: float = 2.0):
                super().__init__(a=a, b=b)

        >>> nd = NewDict()
        >>> nd.a, nd.b
        (1, 2.0)
        >>> nd.c = 3
        >>> nd["d"] = 4
        >>> nd, nd.__dict__
        ({'a': 1, 'b': 2.0, 'c': 3, 'd': 4}, {'a': 1, 'b': 2.0, 'c': 3, 'd': 4})
    """
    # NOTE(review): dict methods that do not go through __setitem__ /
    # __delitem__ (update, setdefault, pop, clear, ...) are presumably not
    # mirrored to the attributes -- override them if that is needed.

    def __init__(self, **kwargs):
        super().__init__()
        for key, value in kwargs.items():
            self[key] = value

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        super().__setitem__(name, value)

    def __delattr__(self, name):
        super().__delattr__(name)
        # Drop the mirrored item as well; dict.pop avoids a second mirror pass.
        dict.pop(self, name, None)

    def __setitem__(self, key, value):
        # Only strings can be attribute names; other keys stay plain items.
        if isinstance(key, str):
            super().__setattr__(key, value)
        super().__setitem__(key, value)

    def __delitem__(self, key):
        super().__delitem__(key)
        # Remove the mirrored attribute so that it cannot outlive the item.
        if isinstance(key, str):
            self.__dict__.pop(key, None)
class _PipeLineParameters(NamedTuple):
    """Keyword arguments handed to a pipeline constructor.

    The helpers in this module build an instance, convert it with
    ``_asdict()`` and unpack the result into ``task_type.pipeline(**...)``,
    so the field names are the keyword names of that call.
    """
    # The dataset properties as a plain dict (the callers pass ``_asdict()``)
    dataset_properties: Dict[str, Any]
    # Component choices to include / exclude, keyed by the pipeline step name
    include: Dict[str, List[str]]
    exclude: Dict[str, List[str]]
    # Optional updates of the hyperparameter search space
    search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
TypeError(f"task_type must be supported class type, but got '{type(task_type)}'") + elif not task_type.is_supported(): + raise TypeError(f"The given task_type '{task_type}' is not supported.") + + +def _check_preprocessor(include: Dict[str, Any], + exclude: Dict[str, Any], + include_preprocessors: List[str], + exclude_preprocessors: List[str]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + + if None not in [include_preprocessors, exclude_preprocessors]: raise ValueError('Cannot specify include_preprocessors and ' 'exclude_preprocessors.') elif include_preprocessors is not None: + """TODO: what is include and exclude? Why don't we use NamedTuple?""" include['feature_preprocessor'] = include_preprocessors elif exclude_preprocessors is not None: exclude['feature_preprocessor'] = exclude_preprocessors - task_type: int = STRING_TO_TASK_TYPES[info['task_type']] - if include_estimators is not None and \ - exclude_estimators is not None: + return include, exclude + + +def _check_estimators(task_type: Union[RegressionTypes, ClassificationTypes], + include: Dict[str, Any], + exclude: Dict[str, Any], + include_estimators: List[str], + exclude_estimators: List[str]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + + if None not in [include_estimators, exclude_estimators]: raise ValueError('Cannot specify include_estimators and ' 'exclude_estimators.') - elif include_estimators is not None: - if task_type in CLASSIFICATION_TASKS: - include['classifier'] = include_estimators - elif task_type in REGRESSION_TASKS: - include['regressor'] = include_estimators - else: - raise ValueError(info['task_type']) + + if include_estimators is not None: + include[task_type.task_name] = include_estimators elif exclude_estimators is not None: - if task_type in CLASSIFICATION_TASKS: - exclude['classifier'] = exclude_estimators - elif task_type in REGRESSION_TASKS: - exclude['regressor'] = exclude_estimators - else: - raise ValueError(info['task_type']) - - if task_type in REGRESSION_TASKS: - return 
_get_regression_dataset_requirements(info, include, exclude) - else: - return _get_classification_dataset_requirements(info, include, exclude) - - -def _get_regression_dataset_requirements(info: Dict[str, Any], include: Dict[str, List[str]], - exclude: Dict[str, List[str]]) -> List[FitRequirement]: - task_type = STRING_TO_TASK_TYPES[info['task_type']] - if task_type in TABULAR_TASKS: - fit_requirements = TabularRegressionPipeline( - dataset_properties=info, - include=include, - exclude=exclude - ).get_dataset_requirements() - return fit_requirements - else: - raise ValueError("Task_type not supported") - - -def _get_classification_dataset_requirements(info: Dict[str, Any], include: Dict[str, List[str]], - exclude: Dict[str, List[str]]) -> List[FitRequirement]: - task_type = STRING_TO_TASK_TYPES[info['task_type']] - - if task_type in TABULAR_TASKS: - return TabularClassificationPipeline( - dataset_properties=info, - include=include, exclude=exclude).\ - get_dataset_requirements() - elif task_type in IMAGE_TASKS: - return ImageClassificationPipeline( - dataset_properties=info, - include=include, exclude=exclude).\ - get_dataset_requirements() - else: - raise ValueError("Task_type not supported") - - -def get_configuration_space(info: Dict[str, Any], + exclude[task_type.task_name] = exclude_estimators + + return include, exclude + + +def get_dataset_requirements(dataset_properties: 'DatasetProperties', # temporal name + include_estimators: Optional[List[str]] = None, + exclude_estimators: Optional[List[str]] = None, + include_preprocessors: Optional[List[str]] = None, + exclude_preprocessors: Optional[List[str]] = None + ) -> List[FitRequirement]: + """TODO: make 'info' (older argument) NamedTuple""" + """TODO: to be compatible with other files using get_dataset_requirements""" + """TODO: DatasetProperties can be merged in BaseDataset in my opinion.""" + include, exclude = dict(), dict() + task_type = dataset_properties.task_type + + _check_supported_tasks(task_type) 
+ + include, exclude = _check_preprocessor(include, exclude, + include_preprocessors, exclude_preprocessors) + include, exclude = _check_estimators(task_type, include, exclude, + include_estimators, exclude_estimators) + + pipeline_params = _PipeLineParameters(dataset_properties=dataset_properties._asdict(), + include=include, exclude=exclude)._asdict() + + return task_type.pipeline(**pipeline_params).get_dataset_requirements() + + +def get_configuration_space(dataset_properties: 'DatasetProperties', # temporal name include: Optional[Dict] = None, exclude: Optional[Dict] = None, search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None ) -> ConfigurationSpace: - task_type: int = STRING_TO_TASK_TYPES[info['task_type']] - - if task_type in REGRESSION_TASKS: - return _get_regression_configuration_space(info, - include if include is not None else {}, - exclude if exclude is not None else {}, - search_space_updates=search_space_updates - ) - else: - return _get_classification_configuration_space(info, - include if include is not None else {}, - exclude if exclude is not None else {}, - search_space_updates=search_space_updates - ) - - -def _get_regression_configuration_space(info: Dict[str, Any], include: Dict[str, List[str]], - exclude: Dict[str, List[str]], - search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None - ) -> ConfigurationSpace: - if STRING_TO_TASK_TYPES[info['task_type']] in TABULAR_TASKS: - configuration_space = TabularRegressionPipeline( - dataset_properties=info, - include=include, - exclude=exclude, - search_space_updates=search_space_updates - ).get_hyperparameter_search_space() - return configuration_space - else: - raise ValueError("Task_type not supported") - - -def _get_classification_configuration_space(info: Dict[str, Any], include: Dict[str, List[str]], - exclude: Dict[str, List[str]], - search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None - ) -> ConfigurationSpace: - if 
STRING_TO_TASK_TYPES[info['task_type']] in TABULAR_TASKS: - pipeline = TabularClassificationPipeline(dataset_properties=info, - include=include, exclude=exclude, - search_space_updates=search_space_updates) - return pipeline.get_hyperparameter_search_space() - elif STRING_TO_TASK_TYPES[info['task_type']] in IMAGE_TASKS: - return ImageClassificationPipeline( - dataset_properties=info, - include=include, exclude=exclude, - search_space_updates=search_space_updates).\ - get_hyperparameter_search_space() - else: - raise ValueError("Task_type not supported") + + task_type = dataset_properties.task_type + + _check_supported_tasks(task_type) + + pipeline_params = _PipeLineParameters(dataset_properties=dataset_properties._asdict(), + include=include, exclude=exclude, + search_space_updates=search_space_updates)._asdict() + + return task_type.pipeline(**pipeline_params).get_hyperparameter_search_space() diff --git a/examples/example_smac_intensify.py b/examples/example_smac_intensify.py index b92c90968..82946b8ab 100644 --- a/examples/example_smac_intensify.py +++ b/examples/example_smac_intensify.py @@ -9,7 +9,7 @@ import sklearn.datasets import sklearn.model_selection -from autoPyTorch.datasets.resampling_strategy import CrossValTypes +from autoPyTorch.datasets.train_val_split import CrossValTypes from autoPyTorch.datasets.tabular_dataset import TabularDataset from autoPyTorch.optimizer.smbo import AutoMLSMBO from autoPyTorch.pipeline.components.training.metrics.utils import get_metrics @@ -97,7 +97,7 @@ def get_data_to_train() -> typing.Tuple[typing.Any, typing.Any, typing.Any, typi datamanager = TabularDataset( X=X_train, Y=y_train, X_test=X_test, Y_test=y_test, - resampling_strategy=CrossValTypes.k_fold_cross_validation) + splitting_type=CrossValTypes.k_fold_cross_validation) backend.save_datamanager(datamanager) # Build a ensemble from the above components diff --git a/test/conftest.py b/test/conftest.py index 195f51e13..3ebe5b562 100644 --- a/test/conftest.py +++ 
b/test/conftest.py @@ -305,15 +305,15 @@ def dataset_traditional_classifier_categorical_only(): X, y = fetch_openml(data_id=40981, return_X_y=True, as_frame=True) categorical_columns = [column for column in X.columns if X[column].dtype.name == 'category'] X = X[categorical_columns] - X, y = X[:200].to_numpy(), y[:200].to_numpy().astype(np.int) + X, y = X[:200].to_numpy(), y[:200].to_numpy().astype(np.int64) return X, y @pytest.fixture def dataset_traditional_classifier_num_categorical(): X, y = fetch_openml(data_id=40981, return_X_y=True, as_frame=True) - y = y.astype(np.int) - X, y = X[:200].to_numpy(), y[:200].to_numpy().astype(np.int) + y = y.astype(np.int64) + X, y = X[:200].to_numpy(), y[:200].to_numpy().astype(np.int64) return X, y diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index ce9a88e2e..8b5cd319c 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -13,9 +13,9 @@ import torch from autoPyTorch.api.tabular_classification import TabularClassificationTask -from autoPyTorch.datasets.resampling_strategy import ( +from autoPyTorch.datasets.train_val_split import ( CrossValTypes, - HoldoutValTypes, + HoldOutTypes, ) from autoPyTorch.datasets.tabular_dataset import TabularDataset @@ -27,10 +27,9 @@ # Test # ======== @pytest.mark.parametrize('openml_id', (40981, )) -@pytest.mark.parametrize('resampling_strategy', (HoldoutValTypes.holdout_validation, - CrossValTypes.k_fold_cross_validation, - )) -def test_classification(openml_id, resampling_strategy, backend): +@pytest.mark.parametrize('splitting_type', (HoldOutTypes.holdout_validation, + CrossValTypes.k_fold_cross_validation)) +def test_classification(openml_id, splitting_type, backend): # Get the data and check that contents of data-manager make sense X, y = sklearn.datasets.fetch_openml( @@ -42,11 +41,11 @@ def test_classification(openml_id, resampling_strategy, backend): datamanager = TabularDataset( X=X_train, Y=y_train, X_test=X_test, Y_test=y_test, - 
resampling_strategy=resampling_strategy, + splitting_type=splitting_type, dataset_name=str(openml_id), ) assert datamanager.task_type == 'tabular_classification' - expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 3 + expected_num_splits = 1 if splitting_type == HoldOutTypes.holdout_validation else 3 assert len(datamanager.splits) == expected_num_splits # Search for a good configuration @@ -102,14 +101,14 @@ def test_classification(openml_id, resampling_strategy, backend): if os.path.exists(run_key_model_run_dir): break - if resampling_strategy == HoldoutValTypes.holdout_validation: + if splitting_type == HoldOutTypes.holdout_validation: model_file = os.path.join(run_key_model_run_dir, f"{estimator.seed}.{run_key.config_id}.{run_key.budget}.model") assert os.path.exists(model_file), model_file model = estimator._backend.load_model_by_seed_and_id_and_budget( estimator.seed, run_key.config_id, run_key.budget) assert isinstance(model.named_steps['network'].get_network(), torch.nn.Module) - elif resampling_strategy == CrossValTypes.k_fold_cross_validation: + elif splitting_type == CrossValTypes.k_fold_cross_validation: model_file = os.path.join( run_key_model_run_dir, f"{estimator.seed}.{run_key.config_id}.{run_key.budget}.cv_model" @@ -122,7 +121,7 @@ def test_classification(openml_id, resampling_strategy, backend): assert isinstance(model.estimators_[0].named_steps['network'].get_network(), torch.nn.Module) else: - pytest.fail(resampling_strategy) + pytest.fail(splitting_type) # Make sure that predictions on the test data are printed and make sense test_prediction = os.path.join(run_key_model_run_dir, diff --git a/test/test_datasets/test_tabular_dataset.py b/test/test_datasets/test_tabular_dataset.py index dfc72be77..1cae0deca 100644 --- a/test/test_datasets/test_tabular_dataset.py +++ b/test/test_datasets/test_tabular_dataset.py @@ -26,7 +26,7 @@ def runTest(self): self.assertEqual(ds.vtois[0][np.nan], 0) 
self.assertEqual(ds.vtois[0][pd._libs.NaT], 0) self.assertEqual(ds.vtois[0][pd._libs.missing.NAType()], 0) - self.assertTrue((ds.nan_mask == np.array([[0, 0, 0], [0, 0, 1]], dtype=np.bool)).all()) + self.assertTrue((ds.nan_mask == np.array([[0, 0, 0], [0, 0, 1]], dtype=np.bool8)).all()) class NumpyArrayTest(unittest.TestCase): @@ -41,7 +41,7 @@ def runTest(self): self.assertEqual(ds.vtois[0][np.nan], 0) self.assertEqual(ds.vtois[0][pd._libs.NaT], 0) self.assertEqual(ds.vtois[0][pd._libs.missing.NAType()], 0) - self.assertTrue((ds.nan_mask == np.array([[0, 0, 0], [0, 1, 0]], dtype=np.bool)).all()) + self.assertTrue((ds.nan_mask == np.array([[0, 0, 0], [0, 1, 0]], dtype=np.bool8)).all()) def get_data_to_train() -> typing.Dict[str, typing.Any]: diff --git a/test/test_evaluation/evaluation_util.py b/test/test_evaluation/evaluation_util.py index b61df8643..4e2b0114a 100644 --- a/test/test_evaluation/evaluation_util.py +++ b/test/test_evaluation/evaluation_util.py @@ -11,7 +11,7 @@ import sklearn.model_selection from sklearn import preprocessing -from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes +from autoPyTorch.datasets.train_val_split import HoldOutTypes from autoPyTorch.datasets.tabular_dataset import TabularDataset from autoPyTorch.pipeline.components.training.metrics.metrics import ( accuracy, @@ -131,7 +131,7 @@ def __fit(self, function_handle): raise e -def get_multiclass_classification_datamanager(resampling_strategy=HoldoutValTypes.holdout_validation): +def get_multiclass_classification_datamanager(splitting_type=HoldOutTypes.holdout_validation): X_train, Y_train, X_test, Y_test = get_dataset('iris') indices = list(range(X_train.shape[0])) np.random.seed(1) @@ -142,12 +142,12 @@ def get_multiclass_classification_datamanager(resampling_strategy=HoldoutValType dataset = TabularDataset( X=X_train, Y=Y_train, X_test=X_test, Y_test=Y_test, - resampling_strategy=resampling_strategy + splitting_type=splitting_type ) return dataset -def 
get_abalone_datamanager(resampling_strategy=HoldoutValTypes.holdout_validation): +def get_abalone_datamanager(splitting_type=HoldOutTypes.holdout_validation): # https://www.openml.org/d/183 X, y = sklearn.datasets.fetch_openml(data_id=183, return_X_y=True, as_frame=False) y = preprocessing.LabelEncoder().fit_transform(y) @@ -158,12 +158,12 @@ def get_abalone_datamanager(resampling_strategy=HoldoutValTypes.holdout_validati dataset = TabularDataset( X=X_train, Y=y_train, X_test=X_test, Y_test=y_test, - resampling_strategy=resampling_strategy + splitting_type=splitting_type ) return dataset -def get_binary_classification_datamanager(resampling_strategy=HoldoutValTypes.holdout_validation): +def get_binary_classification_datamanager(splitting_type=HoldOutTypes.holdout_validation): X_train, Y_train, X_test, Y_test = get_dataset('iris') indices = list(range(X_train.shape[0])) np.random.seed(1) @@ -182,12 +182,12 @@ def get_binary_classification_datamanager(resampling_strategy=HoldoutValTypes.ho dataset = TabularDataset( X=X_train, Y=Y_train, X_test=X_test, Y_test=Y_test, - resampling_strategy=resampling_strategy + splitting_type=splitting_type ) return dataset -def get_regression_datamanager(resampling_strategy=HoldoutValTypes.holdout_validation): +def get_regression_datamanager(splitting_type=HoldOutTypes.holdout_validation): X_train, Y_train, X_test, Y_test = get_dataset('boston') indices = list(range(X_train.shape[0])) np.random.seed(1) @@ -198,12 +198,12 @@ def get_regression_datamanager(resampling_strategy=HoldoutValTypes.holdout_valid dataset = TabularDataset( X=X_train, Y=Y_train, X_test=X_test, Y_test=Y_test, - resampling_strategy=resampling_strategy + splitting_type=splitting_type ) return dataset -def get_500_classes_datamanager(resampling_strategy=HoldoutValTypes.holdout_validation): +def get_500_classes_datamanager(splitting_type=HoldOutTypes.holdout_validation): weights = ([0.002] * 475) + ([0.001] * 25) X, Y = 
sklearn.datasets.make_classification(n_samples=1000, n_features=20, @@ -224,7 +224,7 @@ def get_500_classes_datamanager(resampling_strategy=HoldoutValTypes.holdout_vali dataset = TabularDataset( X=X[:700], Y=Y[:700], X_test=X[700:], Y_test=Y[710:], - resampling_strategy=resampling_strategy + splitting_type=splitting_type ) return dataset diff --git a/test/test_evaluation/test_train_evaluator.py b/test/test_evaluation/test_train_evaluator.py index 67132285e..7f7bd4bf4 100644 --- a/test/test_evaluation/test_train_evaluator.py +++ b/test/test_evaluation/test_train_evaluator.py @@ -14,7 +14,7 @@ from smac.tae import StatusType -from autoPyTorch.datasets.resampling_strategy import CrossValTypes +from autoPyTorch.datasets.train_val_split import CrossValTypes from autoPyTorch.evaluation.train_evaluator import TrainEvaluator from autoPyTorch.evaluation.utils import read_queue from autoPyTorch.pipeline.base_pipeline import BasePipeline @@ -125,7 +125,7 @@ def test_holdout(self, pipeline_mock): @unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline') def test_cv(self, pipeline_mock): - D = get_binary_classification_datamanager(resampling_strategy=CrossValTypes.k_fold_cross_validation) + D = get_binary_classification_datamanager(splitting_type=CrossValTypes.k_fold_cross_validation) pipeline_mock.predict_proba.side_effect = \ lambda X, batch_size=None: np.tile([0.6, 0.4], (len(X), 1))