diff --git a/.codecov.yml b/.codecov.yml
new file mode 100644
index 000000000..743437c7d
--- /dev/null
+++ b/.codecov.yml
@@ -0,0 +1,42 @@
+#see https://github.com/codecov/support/wiki/Codecov-Yaml
+codecov:
+  notify:
+    require_ci_to_pass: yes
+
+coverage:
+  precision: 2 # 2 = xx.xx%, 0 = xx%
+  round: nearest # how coverage is rounded: down/up/nearest
+  range: 10...90 # custom range of coverage colors from red -> yellow -> green
+  status:
+    # https://codecov.readme.io/v1.0/docs/commit-status
+    project:
+      default:
+        against: auto
+        target: 70% # specify the target coverage for each commit status
+        threshold: 50% # allow this little decrease on project
+        # https://github.com/codecov/support/wiki/Filtering-Branches
+        # branches: master
+        if_ci_failed: error
+    # https://github.com/codecov/support/wiki/Patch-Status
+    patch:
+      default:
+        against: auto
+        target: 30% # specify the target "X%" coverage to hit
+        threshold: 50% # allow this much decrease on patch
+    changes: false
+
+parsers:
+  gcov:
+    branch_detection:
+      conditional: true
+      loop: true
+      macro: false
+      method: false
+  javascript:
+    enable_partials: false
+
+comment:
+  layout: header, diff
+  require_changes: false
+  behavior: default # update if exists else create new
+  branches: *
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 000000000..cef5b1c29
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,26 @@
+# .coveragerc to control coverage.py
+[run]
+branch = True
+
+[report]
+# Regexes for lines to exclude from consideration
+exclude_lines =
+    # Have to re-enable the standard pragma
+    pragma: no cover
+
+    # Don't complain about missing debug-only code:
+    def __repr__
+    if self\.debug
+
+    # Don't complain if tests don't hit defensive assertion code:
+    raise AssertionError
+    raise NotImplementedError
+
+    # Don't complain if non-runnable code isn't run:
+    if 0:
+    if __name__ == .__main__.:
+
+ignore_errors = True
+
+[html]
+directory = coverage_html_report
diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 89b4ada60..9f6c13611 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -9,7 +9,10 @@ jobs:
     strategy:
       matrix:
         python-version: [3.6, 3.7, 3.8]
-      fail-fast: false
+        include:
+          - python-version: 3.8
+            code-cov: true
+      fail-fast: false
       max-parallel: 2
 
     steps:
@@ -29,7 +32,7 @@
         echo "::set-output name=BEFORE::$(git status --porcelain -b)"
     - name: Run tests
       run: |
-        if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml'; fi
+        if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml --cov-config=.coveragerc'; fi
        python -m pytest --forked --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
     - name: Check for files left behind by test
       if: ${{ always() }}
diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index 6c366c563..e1e1ffedc 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -34,6 +34,7 @@
     STRING_TO_OUTPUT_TYPES,
     STRING_TO_TASK_TYPES,
 )
+from autoPyTorch.data.base_validator import BaseInputValidator
 from autoPyTorch.datasets.base_dataset import BaseDataset
 from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
 from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
@@ -203,6 +204,8 @@ def __init__(
             self._multiprocessing_context = 'fork'
             self._dask_client = SingleThreadedClient()
 
+        self.InputValidator: Optional[BaseInputValidator] = None
+
         self.search_space_updates = search_space_updates
         if search_space_updates is not None:
             if not isinstance(self.search_space_updates,
@@ -273,8 +276,8 @@ def get_search_space(self, dataset: BaseDataset = None) -> ConfigurationSpace:
                 include=self.include_components,
                 exclude=self.exclude_components,
                 search_space_updates=self.search_space_updates)
-        raise Exception("No search space initialised and no dataset passed. "
-                        "Can't create default search space without the dataset")
+        raise ValueError("No search space initialised and no dataset passed. "
+                         "Can't create default search space without the dataset")
 
     def _get_logger(self, name: str) -> PicklableClientLogger:
         """
diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py
index 9955e706f..9d62f49ad 100644
--- a/autoPyTorch/datasets/base_dataset.py
+++ b/autoPyTorch/datasets/base_dataset.py
@@ -129,7 +129,10 @@ def __init__(
 
         if len(self.train_tensors) == 2 and self.train_tensors[1] is not None:
             self.output_type: str = type_of_target(self.train_tensors[1])
-            if STRING_TO_OUTPUT_TYPES[self.output_type] in CLASSIFICATION_OUTPUTS:
+            if (
+                self.output_type in STRING_TO_OUTPUT_TYPES
+                and STRING_TO_OUTPUT_TYPES[self.output_type] in CLASSIFICATION_OUTPUTS
+            ):
                 self.output_shape = len(np.unique(self.train_tensors[1]))
             else:
                 self.output_shape = self.train_tensors[1].shape[-1] if self.train_tensors[1].ndim > 1 else 1
diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py
index a1e599dd6..ac96c934a 100644
--- a/autoPyTorch/datasets/resampling_strategy.py
+++ b/autoPyTorch/datasets/resampling_strategy.py
@@ -162,7 +162,10 @@ def stratified_k_fold_cross_validation(random_state: np.random.RandomState,
                                            indices: np.ndarray,
                                            **kwargs: Any
                                            ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = StratifiedKFold(n_splits=num_splits, random_state=random_state)
+
+        shuffle = kwargs.get('shuffle', True)
+        cv = StratifiedKFold(n_splits=num_splits, shuffle=shuffle,
+                             random_state=random_state if not shuffle else None)
         splits = list(cv.split(indices, kwargs["stratify"]))
         return splits
 
diff --git a/autoPyTorch/datasets/tabular_dataset.py b/autoPyTorch/datasets/tabular_dataset.py
index 19e483612..c178c755c 100644
--- a/autoPyTorch/datasets/tabular_dataset.py
+++ b/autoPyTorch/datasets/tabular_dataset.py
@@ -24,18 +24,6 @@
 )
 
 
-class Value2Index(object):
-    def __init__(self, values: list):
-        assert all(not (pd.isna(v)) for v in values)
-        self.values = {v: i for i, v in enumerate(values)}
-
-    def __getitem__(self, item: Any) -> int:
-        if pd.isna(item):
-            return 0
-        else:
-            return self.values[item] + 1
-
-
 class TabularDataset(BaseDataset):
     """
     Base class for datasets used in AutoPyTorch
diff --git a/autoPyTorch/search_space/__init__.py b/autoPyTorch/search_space/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/autoPyTorch/search_space/search_space.py b/autoPyTorch/search_space/search_space.py
deleted file mode 100644
index 5587eff15..000000000
--- a/autoPyTorch/search_space/search_space.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import typing
-from typing import Optional
-
-import ConfigSpace as cs
-
-
-class SearchSpace:
-
-    hyperparameter_types = {
-        'categorical': cs.CategoricalHyperparameter,
-        'integer': cs.UniformIntegerHyperparameter,
-        'float': cs.UniformFloatHyperparameter,
-        'constant': cs.Constant,
-    }
-
-    @typing.no_type_check
-    def __init__(
-        self,
-        cs_name: str = 'Default Hyperparameter Config',
-        seed: int = 11,
-    ):
-        """Fit the selected algorithm to the training data.
-
-        Args:
-            cs_name (str): The name of the configuration space.
-            seed (int): Seed value used for the configuration space.
-
-        Returns:
-        """
-        self._hp_search_space = cs.ConfigurationSpace(
-            name=cs_name,
-            seed=seed,
-        )
-
-    @typing.no_type_check
-    def add_hyperparameter(
-        self,
-        name: str,
-        hyperparameter_type: str,
-        **kwargs,
-    ):
-        """Add a new hyperparameter to the configuration space.
-
-        Args:
-            name (str): The name of the hyperparameter to be added.
-            hyperparameter_type (str): The type of the hyperparameter to be added.
-
-        Returns:
-            hyperparameter (cs.Hyperparameter): The hyperparameter that was added
-                to the hyperparameter search space.
-        """
-        missing_arg = SearchSpace._assert_necessary_arguments_given(
-            hyperparameter_type,
-            **kwargs,
-        )
-
-        if missing_arg is not None:
-            raise TypeError(f'A {hyperparameter_type} must have a value for {missing_arg}')
-        else:
-            hyperparameter = SearchSpace.hyperparameter_types[hyperparameter_type](
-                name=name,
-                **kwargs,
-            )
-            self._hp_search_space.add_hyperparameter(
-                hyperparameter
-            )
-
-        return hyperparameter
-
-    @staticmethod
-    @typing.no_type_check
-    def _assert_necessary_arguments_given(
-        hyperparameter_type: str,
-        **kwargs,
-    ) -> Optional[str]:
-        """Assert that given a particular hyperparameter type, all the
-        necessary arguments are given to create the hyperparameter.
-
-        Args:
-            hyperparameter_type (str): The type of the hyperparameter to be added.
-
-        Returns:
-            missing_argument (str|None): The argument that is missing
-                to create the given hyperparameter.
-        """
-        necessary_args = {
-            'categorical': {'choices', 'default_value'},
-            'integer': {'lower', 'upper', 'default', 'log'},
-            'float': {'lower', 'upper', 'default', 'log'},
-            'constant': {'value'},
-        }
-
-        hp_necessary_args = necessary_args[hyperparameter_type]
-        for hp_necessary_arg in hp_necessary_args:
-            if hp_necessary_arg not in kwargs:
-                return hp_necessary_arg
-
-        return None
-
-    @typing.no_type_check
-    def set_parent_hyperperparameter(
-        self,
-        child_hp,
-        parent_hp,
-        parent_value,
-    ):
-        """Activate the child hyperparameter on the search space only if the
-        parent hyperparameter takes a particular value.
-
-        Args:
-            child_hp (cs.Hyperparameter): The child hyperparameter to be added.
-            parent_hp (cs.Hyperparameter): The parent hyperparameter to be considered.
-            parent_value (str|float|int): The value of the parent hyperparameter for when the
-                child hyperparameter will be added to the search space.
-
-        Returns:
-        """
-        self._hp_search_space.add_condition(
-            cs.EqualsCondition(
-                child_hp,
-                parent_hp,
-                parent_value,
-            )
-        )
-
-    @typing.no_type_check
-    def add_configspace_condition(
-        self,
-        child_hp,
-        parent_hp,
-        configspace_condition,
-        value,
-    ):
-        """Add a condition on the chi
-
-        Args:
-            child_hp (cs.Hyperparameter): The child hyperparameter to be added.
-            parent_hp (cs.Hyperparameter): The parent hyperparameter to be considered.
-            configspace_condition (cs.AbstractCondition): The condition to be fullfilled
-                by the parent hyperparameter. A list of all the possible conditions can be
-                found at ConfigSpace/conditions.py.
-            value (str|float|int|list): The value of the parent hyperparameter to be matched
-                in the condition. value needs to be a list only for the InCondition.
-
-        Returns:
-        """
-        self._hp_search_space.add_condition(
-            configspace_condition(
-                child_hp,
-                parent_hp,
-                value,
-            )
-        )
diff --git a/setup.py b/setup.py
index a8522a8dc..3a0a0276b 100755
--- a/setup.py
+++ b/setup.py
@@ -48,7 +48,10 @@
         "codecov",
         "pep8",
         "mypy",
-        "openml"
+        "openml",
+        "emcee",
+        "scikit-optimize",
+        "pyDOE",
     ],
     "examples": [
         "matplotlib",
diff --git a/test/conftest.py b/test/conftest.py
index cdaf53703..aba62122f 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,3 +1,4 @@
+import logging.handlers
 import os
 import re
 import shutil
@@ -299,6 +300,7 @@ def get_fit_dictionary(X, y, validator, backend):
         'metrics_during_training': True,
         'split_id': 0,
         'backend': backend,
+        'logger_port': logging.handlers.DEFAULT_TCP_LOGGING_PORT,
     }
     backend.save_datamanager(datamanager)
     return fit_dictionary
diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py
index aeab572a5..a0752db25 100644
--- a/test/test_api/test_api.py
+++ b/test/test_api/test_api.py
@@ -6,6 +6,7 @@
 import unittest
 from test.test_api.utils import dummy_do_dummy_prediction, dummy_eval_function, dummy_traditional_classification
 
+import ConfigSpace as CS
 from ConfigSpace.configuration_space import Configuration
 
 import numpy as np
@@ -17,6 +18,7 @@
 
 import sklearn
 import sklearn.datasets
+from sklearn.base import BaseEstimator
 from sklearn.base import clone
 from sklearn.ensemble import VotingClassifier, VotingRegressor
@@ -31,6 +33,7 @@
 )
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.optimizer.smbo import AutoMLSMBO
+from autoPyTorch.pipeline.components.setup.traditional_ml.classifier_models import _classifiers
 from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
 
 
@@ -57,7 +60,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend, resampl
     X, y = X.iloc[:n_samples], y.iloc[:n_samples]
 
     X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
-        X, y, random_state=1)
+        X, y, random_state=42)
 
     include = None
     # for python less than 3.7, learned entity embedding
@@ -69,7 +72,8 @@ def test_tabular_classification(openml_id, resampling_strategy, backend, resampl
         backend=backend,
         resampling_strategy=resampling_strategy,
         resampling_strategy_args=resampling_strategy_args,
-        include_components=include
+        include_components=include,
+        seed=42,
     )
 
     with unittest.mock.patch.object(estimator, '_do_dummy_prediction', new=dummy_do_dummy_prediction):
@@ -77,8 +81,8 @@ def test_tabular_classification(openml_id, resampling_strategy, backend, resampl
             X_train=X_train, y_train=y_train,
             X_test=X_test, y_test=y_test,
             optimize_metric='accuracy',
-            total_walltime_limit=30,
-            func_eval_time_limit_secs=5,
+            total_walltime_limit=40,
+            func_eval_time_limit_secs=10,
             enable_traditional_pipeline=False,
         )
 
@@ -98,15 +102,15 @@ def test_tabular_classification(openml_id, resampling_strategy, backend, resampl
     assert len(loaded_datamanager.train_tensors) == len(estimator.dataset.train_tensors)
 
     expected_files = [
-        'smac3-output/run_1/configspace.json',
-        'smac3-output/run_1/runhistory.json',
-        'smac3-output/run_1/scenario.txt',
-        'smac3-output/run_1/stats.json',
-        'smac3-output/run_1/train_insts.txt',
-        'smac3-output/run_1/trajectory.json',
+        'smac3-output/run_42/configspace.json',
+        'smac3-output/run_42/runhistory.json',
+        'smac3-output/run_42/scenario.txt',
+        'smac3-output/run_42/stats.json',
+        'smac3-output/run_42/train_insts.txt',
+        'smac3-output/run_42/trajectory.json',
         '.autoPyTorch/datamanager.pkl',
         '.autoPyTorch/ensemble_read_preds.pkl',
-        '.autoPyTorch/start_time_1',
+        '.autoPyTorch/start_time_42',
         '.autoPyTorch/ensemble_history.json',
         '.autoPyTorch/ensemble_read_losses.pkl',
         '.autoPyTorch/true_targets_ensemble.npy',
@@ -187,9 +191,12 @@ def test_tabular_classification(openml_id, resampling_strategy, backend, resampl
     assert len(estimator.ensemble_.identifiers_) == len(estimator.ensemble_.weights_)
 
     y_pred = estimator.predict(X_test)
-
     assert np.shape(y_pred)[0] == np.shape(X_test)[0]
 
+    # Make sure that predict proba has the expected shape
+    probabilites = estimator.predict_proba(X_test)
+    assert np.shape(probabilites) == (np.shape(X_test)[0], 2)
+
     score = estimator.score(y_pred, y_test)
     assert 'accuracy' in score
 
@@ -215,6 +222,12 @@
         restored_estimator = pickle.load(f)
     restored_estimator.predict(X_test)
 
+    # Test refit on dummy data
+    estimator.refit(dataset=backend.load_datamanager())
+
+    # Make sure that a configuration space is stored in the estimator
+    assert isinstance(estimator.get_search_space(), CS.ConfigurationSpace)
+
 
 @pytest.mark.parametrize('openml_name', ("boston", ))
 @unittest.mock.patch('autoPyTorch.evaluation.train_evaluator.eval_function',
@@ -257,7 +270,8 @@ def test_tabular_regression(openml_name, resampling_strategy, backend, resamplin
         backend=backend,
         resampling_strategy=resampling_strategy,
         resampling_strategy_args=resampling_strategy_args,
-        include_components=include
+        include_components=include,
+        seed=42,
     )
 
     with unittest.mock.patch.object(estimator, '_do_dummy_prediction', new=dummy_do_dummy_prediction):
@@ -265,8 +279,8 @@
             X_train=X_train, y_train=y_train,
             X_test=X_test, y_test=y_test,
             optimize_metric='r2',
-            total_walltime_limit=30,
-            func_eval_time_limit_secs=5,
+            total_walltime_limit=40,
+            func_eval_time_limit_secs=10,
             enable_traditional_pipeline=False,
         )
 
@@ -286,15 +300,15 @@
     assert len(loaded_datamanager.train_tensors) == len(estimator.dataset.train_tensors)
 
     expected_files = [
-        'smac3-output/run_1/configspace.json',
-        'smac3-output/run_1/runhistory.json',
-        'smac3-output/run_1/scenario.txt',
-        'smac3-output/run_1/stats.json',
-        'smac3-output/run_1/train_insts.txt',
-        'smac3-output/run_1/trajectory.json',
+        'smac3-output/run_42/configspace.json',
+        'smac3-output/run_42/runhistory.json',
+        'smac3-output/run_42/scenario.txt',
+        'smac3-output/run_42/stats.json',
+        'smac3-output/run_42/train_insts.txt',
+        'smac3-output/run_42/trajectory.json',
         '.autoPyTorch/datamanager.pkl',
         '.autoPyTorch/ensemble_read_preds.pkl',
-        '.autoPyTorch/start_time_1',
+        '.autoPyTorch/start_time_42',
         '.autoPyTorch/ensemble_history.json',
         '.autoPyTorch/ensemble_read_losses.pkl',
         '.autoPyTorch/true_targets_ensemble.npy',
@@ -459,6 +473,11 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
     estimator._disable_file_output = []
     estimator._all_supported_metrics = False
 
+    with pytest.raises(ValueError, match=r".*Dummy prediction failed with run state.*"):
+        with unittest.mock.patch('autoPyTorch.evaluation.train_evaluator.eval_function') as dummy:
+            dummy.side_effect = MemoryError
+            estimator._do_dummy_prediction()
+
     estimator._do_dummy_prediction()
 
     # Ensure that the dummy predictions are not in the current working
@@ -639,3 +658,77 @@ def test_get_incumbent_results(dataset_name, backend, include_traditional):
     if not include_traditional:
         assert results['configuration_origin'] != 'traditional'
+
+
+# TODO: Make faster when https://github.com/automl/Auto-PyTorch/pull/223 is incorporated
+@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
+def test_do_traditional_pipeline(fit_dictionary_tabular):
+    backend = fit_dictionary_tabular['backend']
+    estimator = TabularClassificationTask(
+        backend=backend,
+        resampling_strategy=HoldoutValTypes.holdout_validation,
+        ensemble_size=0,
+    )
+
+    # Setup pre-requisites normally set by search()
+    estimator._create_dask_client()
+    estimator._metric = accuracy
+    estimator._logger = estimator._get_logger('test')
+    estimator._memory_limit = 5000
+    estimator._time_for_task = 60
+    estimator._disable_file_output = []
+    estimator._all_supported_metrics = False
+
+    estimator._do_traditional_prediction(time_left=60, func_eval_time_limit_secs=30)
+
+    # The models should not be on the current directory
+    assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
+
+    # Then we should have fitted 5 classifiers
+    # Maybe some of them fail (unlikely, but we do not control external API)
+    # but we want to make this test robust
+    at_least_one_model_checked = False
+    for i in range(2, 7):
+        pred_path = os.path.join(
+            backend.temporary_directory, '.autoPyTorch', 'runs', f"1_{i}_50.0",
+            f"predictions_ensemble_1_{i}_50.0.npy"
+        )
+        if not os.path.exists(pred_path):
+            continue
+
+        model_path = os.path.join(backend.temporary_directory,
+                                  '.autoPyTorch',
+                                  'runs', f"1_{i}_50.0",
+                                  f"1.{i}.50.0.model")
+
+        # Make sure the dummy model complies with scikit learn
+        # get/set params
+        assert os.path.exists(model_path)
+        with open(model_path, 'rb') as model_handler:
+            model = pickle.load(model_handler)
+        clone(model)
+        assert model.config == list(_classifiers.keys())[i - 2]
+        at_least_one_model_checked = True
+    if not at_least_one_model_checked:
+        pytest.fail("Not even one single traditional pipeline was fitted")
+
+    estimator._close_dask_client()
+    estimator._clean_logger()
+
+    del estimator
+
+
+@pytest.mark.parametrize("api_type", [TabularClassificationTask, TabularRegressionTask])
+def test_unsupported_msg(api_type):
+    api = api_type()
+    with pytest.raises(ValueError, match=r".*is only supported after calling search. Kindly .*"):
+        api.predict(np.ones((10, 10)))
+
+
+@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
+@pytest.mark.parametrize("api_type", [TabularClassificationTask, TabularRegressionTask])
+def test_build_pipeline(api_type, fit_dictionary_tabular):
+    api = api_type()
+    pipeline = api.build_pipeline(fit_dictionary_tabular['dataset_properties'])
+    assert isinstance(pipeline, BaseEstimator)
+    assert len(pipeline.steps) > 0
diff --git a/test/test_api/test_base_api.py b/test/test_api/test_base_api.py
new file mode 100644
index 000000000..95ab6a0e4
--- /dev/null
+++ b/test/test_api/test_base_api.py
@@ -0,0 +1,89 @@
+import logging
+import re
+import unittest
+from unittest.mock import MagicMock
+
+import numpy as np
+
+import pytest
+
+from autoPyTorch.api.base_task import BaseTask, _pipeline_predict
+from autoPyTorch.constants import TABULAR_CLASSIFICATION, TABULAR_REGRESSION
+from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
+
+
+# ====
+# Test
+# ====
+@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
+def test_nonsupported_arguments(fit_dictionary_tabular):
+    with pytest.raises(ValueError, match=r".*Expected search space updates to be of instance.*"):
+        api = BaseTask(search_space_updates='None')
+
+    api = BaseTask()
+    with pytest.raises(ValueError, match=r".*Invalid configuration arguments given.*"):
+        api.set_pipeline_config(unsupported=True)
+    with pytest.raises(ValueError, match=r".*No search space initialised and no dataset.*"):
+        api.get_search_space()
+    api.resampling_strategy = None
+    with pytest.raises(ValueError, match=r".*Resampling strategy is needed to determine.*"):
+        api._load_models()
+    api.resampling_strategy = unittest.mock.MagicMock()
+    with pytest.raises(ValueError, match=r".*Providing a metric to AutoPytorch is required.*"):
+        api._load_models()
+    api.ensemble_ = unittest.mock.MagicMock()
+    with pytest.raises(ValueError, match=r".*No metric found. Either fit/search has not been.*"):
+        api.score(np.ones(10), np.ones(10))
+    api.task_type = None
+    api._metric = MagicMock()
+    with pytest.raises(ValueError, match=r".*AutoPytorch failed to infer a task type*"):
+        api.score(np.ones(10), np.ones(10))
+    api._metric = unittest.mock.MagicMock()
+    with pytest.raises(ValueError, match=r".*No valid model found in run history.*"):
+        api._load_models()
+    dataset = fit_dictionary_tabular['backend'].load_datamanager()
+    with pytest.raises(ValueError, match=r".*Incompatible dataset entered for current task.*"):
+        api._search('accuracy', dataset)
+
+    def returnfalse():
+        return False
+
+    api._load_models = returnfalse
+    with pytest.raises(ValueError, match=r".*No ensemble found. Either fit has not yet.*"):
+        api.predict(np.ones((10, 10)))
+    with pytest.raises(ValueError, match=r".*No ensemble found. Either fit has not yet.*"):
+        api.predict(np.ones((10, 10)))
+
+
+def test_pipeline_predict_function():
+    X = np.ones((10, 10))
+    pipeline = MagicMock()
+    pipeline.predict.return_value = np.full((10,), 3)
+    pipeline.predict_proba.return_value = np.full((10, 2), 3)
+
+    # First handle the classification case
+    task = TABULAR_CLASSIFICATION
+    with pytest.raises(ValueError, match='prediction probability not within'):
+        _pipeline_predict(pipeline, X, 5, logging.getLogger, task)
+    pipeline.predict_proba.return_value = np.zeros((10, 2))
+    predictions = _pipeline_predict(pipeline, X, 5, logging.getLogger(), task)
+    assert np.shape(predictions) == (10, 2)
+
+    task = TABULAR_REGRESSION
+    predictions = _pipeline_predict(pipeline, X, 5, logging.getLogger(), task)
+    assert np.shape(predictions) == (10,)
+    # Trigger warning msg with different shape for prediction
+    pipeline.predict.return_value = np.full((12,), 3)
+    predictions = _pipeline_predict(pipeline, X, 5, logging.getLogger(), task)
+
+
+@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
+def test_show_models(fit_dictionary_tabular):
+    api = BaseTask()
+    api.ensemble_ = MagicMock()
+    api.models_ = [TabularClassificationPipeline(dataset_properties=fit_dictionary_tabular['dataset_properties'])]
+    api.ensemble_.get_models_with_weights.return_value = [(1.0, api.models_[0])]
+    # Expect the default configuration
+    expected = (r"0\s+|\s+SimpleImputer,OneHotEncoder,NoScaler,NoFeaturePreprocessing\s+"
+                r"|\s+no embedding,ShapedMLPBackbone,FullyConnectedHead,nn.Sequential\s+|\s+1")
+    assert re.search(expected, api.show_models()) is not None
diff --git a/test/test_datasets/test_resampling_strategies.py b/test/test_datasets/test_resampling_strategies.py
new file mode 100644
index 000000000..7f14275a3
--- /dev/null
+++ b/test/test_datasets/test_resampling_strategies.py
@@ -0,0 +1,42 @@
+import numpy as np
+
+from autoPyTorch.datasets.resampling_strategy import CrossValFuncs, HoldOutFuncs
+
+
+def test_holdoutfuncs():
+    split = HoldOutFuncs()
+    X = np.arange(10)
+    y = np.ones(10)
+    # Create a minority class
+    y[:2] = 0
+    train, val = split.holdout_validation(0, 0.5, X, shuffle=False)
+    assert len(train) == len(val) == 5
+
+    # No shuffling
+    np.testing.assert_array_equal(X, np.arange(10))
+
+    # Make sure the stratified version splits the minority class
+    train, val = split.stratified_holdout_validation(0, 0.5, X, stratify=y)
+    assert 0 in y[val]
+    assert 0 in y[train]
+
+
+def test_crossvalfuncs():
+    split = CrossValFuncs()
+    X = np.arange(100)
+    y = np.ones(100)
+    # Create a minority class
+    y[:11] = 0
+    splits = split.shuffle_split_cross_validation(0, 10, X)
+    assert len(splits) == 10
+    assert all([len(s[1]) == 10 for s in splits])
+
+    # Make sure the stratified version splits the minority class
+    splits = split.stratified_shuffle_split_cross_validation(0, 10, X, stratify=y)
+    assert len(splits) == 10
+    assert all([0 in y[s[1]] for s in splits])
+
+    #
+    splits = split.stratified_k_fold_cross_validation(0, 10, X, stratify=y)
+    assert len(splits) == 10
+    assert all([0 in y[s[1]] for s in splits])
diff --git a/test/test_datasets/test_tabular_dataset.py b/test/test_datasets/test_tabular_dataset.py
index ab0d09b9b..409e6bdec 100644
--- a/test/test_datasets/test_tabular_dataset.py
+++ b/test/test_datasets/test_tabular_dataset.py
@@ -1,5 +1,8 @@
+import numpy as np
+
 import pytest
 
+from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.utils.pipeline import get_dataset_requirements
 
 
@@ -38,3 +41,8 @@ def test_get_dataset_properties(backend, fit_dictionary_tabular):
     assert datamanager.train_tensors[0].shape == fit_dictionary_tabular['X_train'].shape
     assert datamanager.train_tensors[1].shape == fit_dictionary_tabular['y_train'].shape
     assert datamanager.task_type == 'tabular_classification'
+
+
+def test_not_supported():
+    with pytest.raises(ValueError, match=r".*A feature validator is required to build.*"):
+        TabularDataset(np.ones(10), np.ones(10))
diff --git a/test/test_evaluation/test_evaluation.py b/test/test_evaluation/test_evaluation.py
index 9afa8969f..e17eae6af 100644
--- a/test/test_evaluation/test_evaluation.py
+++ b/test/test_evaluation/test_evaluation.py
@@ -10,12 +10,15 @@
 
 import pynisher
 
+import pytest
+
 from smac.runhistory.runhistory import RunInfo
 from smac.stats.stats import Stats
 from smac.tae import StatusType
+from smac.utils.constants import MAXINT
 
 from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash
-from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
+from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy, log_loss
 
 this_directory = os.path.dirname(__file__)
 sys.path.append(this_directory)
@@ -391,3 +394,8 @@ def test_silent_exception_in_target_function(self):
         self.assertNotIn('exitcode', info[1].additional_info)
         self.assertNotIn('exit_status', info[1].additional_info)
         self.assertNotIn('traceback', info[1])
+
+
+@pytest.mark.parametrize("metric,expected", [(accuracy, 1.0), (log_loss, MAXINT)])
+def test_get_cost_of_crash(metric, expected):
+    assert get_cost_of_crash(metric) == expected
diff --git a/test/test_pipeline/components/training/base.py b/test/test_pipeline/components/training/base.py
index 98ab27b31..b4db199e1 100644
--- a/test/test_pipeline/components/training/base.py
+++ b/test/test_pipeline/components/training/base.py
@@ -1,6 +1,7 @@
 import logging
 
 from sklearn.datasets import make_classification, make_regression
+from sklearn.preprocessing import StandardScaler
 
 import torch
 
@@ -34,9 +35,11 @@ def prepare_trainer(self,
                 n_repeated=0,
                 n_classes=2,
                 n_clusters_per_class=2,
+                class_sep=3.0,
                 shuffle=True,
                 random_state=0
             )
+            X = StandardScaler().fit_transform(X)
             X = torch.tensor(X, dtype=torch.float)
             y = torch.tensor(y, dtype=torch.long)
             output_type = BINARY
@@ -52,6 +55,7 @@ def prepare_trainer(self,
                 shuffle=True,
                 random_state=0
             )
+            X = StandardScaler().fit_transform(X)
             X = torch.tensor(X, dtype=torch.float)
             y = torch.tensor(y, dtype=torch.float)
             # normalize targets for regression since NNs are better when predicting small outputs
diff --git a/test/test_pipeline/components/training/test_image_data_loader.py b/test/test_pipeline/components/training/test_image_data_loader.py
index 76023fec3..af70cf77b 100644
--- a/test/test_pipeline/components/training/test_image_data_loader.py
+++ b/test/test_pipeline/components/training/test_image_data_loader.py
@@ -8,21 +8,28 @@
 )
 
 
-class TestFeatureDataLoader(unittest.TestCase):
-    def test_build_transform(self):
-        """
-        Makes sure a proper composition is created
-        """
-        loader = ImageDataLoader()
+def test_imageloader_build_transform():
+    """
+    Makes sure a proper composition is created
+    """
+    loader = ImageDataLoader()
 
-        fit_dictionary = dict()
-        fit_dictionary['dataset_properties'] = dict()
-        fit_dictionary['dataset_properties']['is_small_preprocess'] = unittest.mock.Mock(())
-        fit_dictionary['image_augmenter'] = unittest.mock.Mock()
+    fit_dictionary = dict()
+    fit_dictionary['dataset_properties'] = dict()
+    fit_dictionary['dataset_properties']['is_small_preprocess'] = unittest.mock.Mock(())
+    fit_dictionary['image_augmenter'] = unittest.mock.Mock()
+    fit_dictionary['preprocess_transforms'] = unittest.mock.Mock()
 
-        compose = loader.build_transform(fit_dictionary, mode='train')
+    compose = loader.build_transform(fit_dictionary, mode='train')
 
-        self.assertIsInstance(compose, torchvision.transforms.Compose)
+    assert isinstance(compose, torchvision.transforms.Compose)
 
-        # We expect to tensor and image augmenter
-        self.assertEqual(len(compose.transforms), 2)
+    # We expect to tensor and image augmenter
+    assert len(compose.transforms) == 2
+
+    compose = loader.build_transform(fit_dictionary, mode='test')
+    assert isinstance(compose, torchvision.transforms.Compose)
+    assert len(compose.transforms) == 2
+
+    # Check the expected error msgs
+    loader._check_transform_requirements(fit_dictionary)
diff --git a/test/test_pipeline/test_traditional_pipeline.py b/test/test_pipeline/test_traditional_pipeline.py
new file mode 100644
index 000000000..96b41302a
--- /dev/null
+++ b/test/test_pipeline/test_traditional_pipeline.py
@@ -0,0 +1,46 @@
+import ConfigSpace as CS
+
+import numpy as np
+
+import pytest
+
+from autoPyTorch.pipeline.components.setup.traditional_ml.classifier_models import _classifiers
+from autoPyTorch.pipeline.traditional_tabular_classification import (
+    TraditionalTabularClassificationPipeline,
+)
+
+
+@pytest.mark.parametrize("fit_dictionary_tabular",
+                         ['classification_numerical_and_categorical'], indirect=True)
+def test_traditional_tabular_pipeline(fit_dictionary_tabular):
+    pipeline = TraditionalTabularClassificationPipeline(
+        dataset_properties=fit_dictionary_tabular['dataset_properties']
+    )
+    assert pipeline._get_estimator_hyperparameter_name() == "tabular_classifier"
+    cs = pipeline.get_hyperparameter_search_space()
+    assert isinstance(cs, CS.ConfigurationSpace)
+    config = cs.sample_configuration()
+    assert config['model_trainer:tabular_classifier:classifier'] in _classifiers
+    assert pipeline.get_pipeline_representation() == {
+        'Preprocessing': 'None',
+        'Estimator': 'TabularClassifier',
+    }
+
+
+@pytest.mark.parametrize("fit_dictionary_tabular",
+                         ['classification_numerical_and_categorical'], indirect=True)
+def test_traditional_tabular_pipeline_predict(fit_dictionary_tabular):
+    pipeline = TraditionalTabularClassificationPipeline(
+        dataset_properties=fit_dictionary_tabular['dataset_properties']
+    )
+    assert pipeline._get_estimator_hyperparameter_name() == "tabular_classifier"
+    config = pipeline.get_hyperparameter_search_space().get_default_configuration()
+    pipeline.set_hyperparameters(config)
+    pipeline.fit(fit_dictionary_tabular)
+    prediction = pipeline.predict(fit_dictionary_tabular['X_train'])
+    assert np.shape(fit_dictionary_tabular['X_train'])[0] == prediction.shape[0]
+    assert prediction.shape[1] == 1
+    prediction = pipeline.predict(fit_dictionary_tabular['X_train'], batch_size=5)
+    assert np.shape(fit_dictionary_tabular['X_train'])[0] == prediction.shape[0]
+    prediction = pipeline.predict_proba(fit_dictionary_tabular['X_train'], batch_size=5)
+    assert np.shape(fit_dictionary_tabular['X_train'])[0] == prediction.shape[0]
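
A note on the `stratified_k_fold_cross_validation` change in `autoPyTorch/datasets/resampling_strategy.py` above: recent scikit-learn releases only accept a `random_state` when `shuffle=True`, which is why the patched code forwards the seed conditionally. Below is a minimal sketch of that contract, not part of the patch; it assumes scikit-learn >= 0.24 and uses illustrative data shapes only.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.arange(20).reshape(-1, 1)
y = np.array([0, 1] * 10)

# With shuffling enabled, a seed (int or RandomState) is accepted and makes
# the folds reproducible.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=np.random.RandomState(42))
splits = list(cv.split(X, y))

# With shuffling disabled, random_state must stay None; newer scikit-learn
# versions raise a ValueError if a seed is passed together with shuffle=False.
cv = StratifiedKFold(n_splits=5, shuffle=False, random_state=None)
splits_unshuffled = list(cv.split(X, y))
```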