diff --git a/.appveyor.yml b/.appveyor.yml
index 3e0941ad..745563c2 100644
--- a/.appveyor.yml
+++ b/.appveyor.yml
@@ -21,7 +21,7 @@ install:
   - conda info -a
   - conda create -q -n test-environment python=%PYTHON_VERSION% numpy scipy scikit-learn nose cython pandas
   - activate test-environment
-  - pip install deap tqdm update_checker pypiwin32 stopit
+  - pip install deap tqdm update_checker pypiwin32 stopit dask[delayed] git+https://github.com/dask/dask-ml

test_script:
diff --git a/ci/.travis_install.sh b/ci/.travis_install.sh
index 212f1ebf..3d8b3f82 100755
--- a/ci/.travis_install.sh
+++ b/ci/.travis_install.sh
@@ -54,6 +54,8 @@ pip install update_checker
 pip install tqdm
 pip install stopit
 pip install xgboost
+pip install dask[delayed]
+pip install git+https://github.com/dask/dask-ml  # TODO: Change to >=0.9 when released

 if [[ "$COVERAGE" == "true" ]]; then
     pip install coverage coveralls
diff --git a/docs_sources/using.md b/docs_sources/using.md
index b06b84e2..9482678c 100644
--- a/docs_sources/using.md
+++ b/docs_sources/using.md
@@ -555,6 +555,47 @@ rmtree(cachedir)

 **Note: TPOT does NOT clean up memory caches if users set a custom directory path or Memory object. We recommend that you clean up the memory caches when you don't need it anymore.**

+# Parallel Training
+
+Internally, TPOT uses [joblib](http://joblib.readthedocs.io/) to fit estimators in parallel.
+This is the same parallelization framework used by scikit-learn.
+
+When you specify ``n_jobs``, TPOT will use ``n_jobs`` processes to fit models in parallel.
+For large problems, you can distribute the work on a [Dask](http://dask.pydata.org/en/latest/) cluster.
+There are two ways to achieve this.
+
+First, you can specify the ``use_dask`` keyword when you create the TPOT estimator.
+
+```python
+estimator = TPOTClassifier(n_jobs=-1, use_dask=True)
+```
+
+This will use all the workers on your cluster to do the training, and use [Dask-ML's pipeline rewriting](https://dask-ml.readthedocs.io/en/latest/hyper-parameter-search.html#avoid-repeated-work) to avoid re-fitting estimators multiple times on the same set of data.
+It will provide fine-grained diagnostics in the [distributed scheduler UI](https://distributed.readthedocs.io/en/latest/web.html).
+
+Alternatively, Dask implements a joblib backend.
+You can instruct TPOT to use the distributed backend during training by calling ``fit`` inside a ``joblib.parallel_backend`` context manager:
+
+```python
+from sklearn.externals import joblib
+import distributed.joblib
+from dask.distributed import Client
+
+# connect to the cluster
+client = Client('scheduler-address')
+
+# create the estimator normally
+estimator = TPOTClassifier(n_jobs=-1)
+
+# perform the fit in this context manager
+with joblib.parallel_backend("dask"):
+    estimator.fit(X, y)
+```
+
+See [dask's distributed joblib integration](https://distributed.readthedocs.io/en/latest/joblib.html) for more.
+
+We recommend using the `use_dask` keyword.
+
 # Crash/freeze issue with n_jobs > 1 under OSX or Linux

 TPOT supports parallel computing for speeding up the optimization process, but it may crash/freeze with n_jobs > 1 under OSX or Linux [as scikit-learn does](http://scikit-learn.org/stable/faq.html#why-do-i-sometime-get-a-crash-freeze-with-n-jobs-1-under-osx-or-linux), especially with large datasets.
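A quick way to exercise the ``use_dask`` path documented above, end to end. This sketch is not part of the patch: it assumes the optional ``dask``, ``distributed``, and ``dask-ml`` dependencies are installed (see the ``dask`` extra added to ``setup.py`` below), and uses scikit-learn's ``make_classification`` purely as stand-in data.

```python
# Minimal sketch, not part of the patch: trying the new use_dask keyword
# against a local Dask cluster.
from dask.distributed import Client
from sklearn.datasets import make_classification

from tpot import TPOTClassifier

# With no arguments, Client() starts a local cluster using all local cores;
# pass a scheduler address string instead to join an existing cluster.
client = Client()

X, y = make_classification(n_samples=500, random_state=0)

# use_dask=True routes each pipeline's cross-validation through Dask-ML's
# graph builder, so repeated fits on the same data split are deduplicated
# and show up in the distributed scheduler UI.
estimator = TPOTClassifier(generations=2, population_size=10,
                           n_jobs=-1, use_dask=True, verbosity=2)
estimator.fit(X, y)
```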
diff --git a/setup.py b/setup.py
index 795d6893..049cbe73 100644
--- a/setup.py
+++ b/setup.py
@@ -46,7 +46,10 @@ def calculate_version():
     extras_require={
         'xgboost': ['xgboost==0.6a2'],
         'skrebate': ['skrebate>=0.3.4'],
-        'mdr': ['scikit-mdr>=0.4.4']
+        'mdr': ['scikit-mdr>=0.4.4'],
+        'dask': ['dask>=0.18.2',
+                 'distributed>=1.22.1',
+                 'dask-ml>=0.8.0'],
     },
     classifiers=[
         'Intended Audience :: Science/Research',
diff --git a/tests/driver_tests.py b/tests/driver_tests.py
index c39bfba8..3204012a 100644
--- a/tests/driver_tests.py
+++ b/tests/driver_tests.py
@@ -27,19 +27,23 @@ import sys
 from os import remove, path
 from contextlib import contextmanager
+from distutils.version import LooseVersion

 try:
     from StringIO import StringIO
 except ImportError:
     from io import StringIO

+import nose
 import numpy as np
 import pandas as pd
+import sklearn

 from tpot.driver import positive_integer, float_range, _get_arg_parser, _print_args, _read_data_file, load_scoring_function, tpot_driver
 from nose.tools import assert_raises, assert_equal, assert_in
 from unittest import TestCase

+
 @contextmanager
 def captured_output():
     new_out, new_err = StringIO(), StringIO()
@@ -169,6 +173,12 @@ def test_driver_4():

 def test_driver_5():
     """Assert that the tpot_driver() in TPOT driver outputs normal result with exported python file and verbosity = 0."""
+
+    # Catch FutureWarning https://github.com/scikit-learn/scikit-learn/issues/11785
+    if (LooseVersion(np.__version__) >= LooseVersion("1.15.0") and
+            LooseVersion(sklearn.__version__) <= LooseVersion("0.20.0")):
+        raise nose.SkipTest("Warning raised by scikit-learn")
+
     args_list = [
         'tests/tests.csv',
         '-is', ',',
diff --git a/tests/test_dask_based.py b/tests/test_dask_based.py
new file mode 100644
index 00000000..2aaa983e
--- /dev/null
+++ b/tests/test_dask_based.py
@@ -0,0 +1,78 @@
+"""Tests that ensure the dask-based fit matches the non-dask fit.
+
+https://github.com/DEAP/deap/issues/75
+"""
+import unittest
+
+import nose
+from sklearn.datasets import make_classification
+
+from tpot import TPOTClassifier
+
+try:
+    import dask  # noqa
+    import dask_ml  # noqa
+except ImportError:
+    raise nose.SkipTest()
+
+
+class TestDaskMatches(unittest.TestCase):
+
+    def test_dask_matches(self):
+        with dask.config.set(scheduler='single-threaded'):
+            for n_jobs in [-1]:
+                X, y = make_classification(random_state=0)
+                a = TPOTClassifier(
+                    generations=2,
+                    population_size=5,
+                    cv=3,
+                    random_state=0,
+                    n_jobs=n_jobs,
+                    use_dask=False,
+                )
+                b = TPOTClassifier(
+                    generations=2,
+                    population_size=5,
+                    cv=3,
+                    random_state=0,
+                    n_jobs=n_jobs,
+                    use_dask=True,
+                )
+                b.fit(X, y)
+                a.fit(X, y)
+
+                self.assertEqual(a.score(X, y), b.score(X, y))
+                self.assertEqual(a.pareto_front_fitted_pipelines_.keys(),
+                                 b.pareto_front_fitted_pipelines_.keys())
+                self.assertEqual(a.evaluated_individuals_,
+                                 b.evaluated_individuals_)
+
+    def test_handles_errors(self):
+        X, y = make_classification(n_samples=5)
+
+        tpot_config = {
+            'sklearn.neighbors.KNeighborsClassifier': {
+                'n_neighbors': [1, 100],
+            }
+        }
+
+        a = TPOTClassifier(
+            generations=2,
+            population_size=5,
+            cv=3,
+            random_state=0,
+            n_jobs=-1,
+            config_dict=tpot_config,
+            use_dask=False,
+        )
+        b = TPOTClassifier(
+            generations=2,
+            population_size=5,
+            cv=3,
+            random_state=0,
+            n_jobs=-1,
+            config_dict=tpot_config,
+            use_dask=True,
+        )
+        a.fit(X, y)
+        b.fit(X, y)
diff --git a/tpot/base.py b/tpot/base.py
index 2b15cc54..821d8572 100644
--- a/tpot/base.py
+++ b/tpot/base.py
@@ -124,7 +124,8 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
                  random_state=None, config_dict=None, warm_start=False,
                  memory=None,
                  periodic_checkpoint_folder=None, early_stop=None,
-                 verbosity=0, disable_update_check=False):
+                 verbosity=0, disable_update_check=False,
+                 use_dask=False):
         """Set up the genetic programming algorithm for pipeline optimization.

         Parameters
@@ -252,6 +253,14 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
             A setting of 2 or higher will add a progress bar during the optimization procedure.
         disable_update_check: bool, optional (default: False)
             Flag indicating whether the TPOT version checker should be disabled.
+        use_dask : bool, default False
+            Whether to use Dask-ML's pipeline optimizations. This avoids re-fitting
+            the same estimator on the same split of data multiple times. It
+            will also provide more detailed diagnostics when using Dask's
+            distributed scheduler.
+
+            See `avoid repeated work <https://dask-ml.readthedocs.io/en/latest/hyper-parameter-search.html#avoid-repeated-work>`__
+            for more.

         Returns
         -------
         None

         """
@@ -286,6 +295,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
         self._last_optimized_pareto_front_n_gens = 0
         self.memory = memory
         self._memory = None  # initial Memory setting for sklearn pipeline
+        self.use_dask = use_dask

         # dont save periodic pipelines more often than this
         self._output_best_pipeline_period_seconds = 30
@@ -1190,12 +1200,13 @@ def _evaluate_individuals(self, individuals, features, target, sample_weight=Non
             scoring_function=self.scoring_function,
             sample_weight=sample_weight,
             groups=groups,
-            timeout=self.max_eval_time_seconds
+            timeout=self.max_eval_time_seconds,
+            use_dask=self.use_dask,
         )
         result_score_list = []

         # Don't use parallelization if n_jobs==1
-        if self.n_jobs == 1:
+        if self.n_jobs == 1 and not self.use_dask:
             for sklearn_pipeline in sklearn_pipeline_list:
                 self._stop_by_max_time_mins()
                 val = partial_wrapped_cross_val_score(sklearn_pipeline=sklearn_pipeline)
@@ -1203,15 +1214,34 @@
         else:
             # chunk size for pbar update
             # chunk size is min of cpu_count * 2 and n_jobs * 4
-            chunk_size = min(cpu_count()*2, self.n_jobs*4)
-            for chunk_idx in range(0, len(sklearn_pipeline_list), chunk_size):
-                self._stop_by_max_time_mins()
-                parallel = Parallel(n_jobs=self.n_jobs, verbose=0, pre_dispatch='2*n_jobs')
-                tmp_result_scores = parallel(delayed(partial_wrapped_cross_val_score)(sklearn_pipeline=sklearn_pipeline)
-                                             for sklearn_pipeline in sklearn_pipeline_list[chunk_idx:chunk_idx + chunk_size])
-                # update pbar
-                for val in tmp_result_scores:
-                    result_score_list = self._update_val(val, result_score_list)
+            if self.use_dask:
+                import dask
+
+                result_score_list = [
+                    partial_wrapped_cross_val_score(sklearn_pipeline=sklearn_pipeline)
+                    for sklearn_pipeline in sklearn_pipeline_list
+                ]
+
+                self.dask_graphs_ = result_score_list
+                with warnings.catch_warnings():
+                    warnings.simplefilter('ignore')
+                    result_score_list = list(dask.compute(*result_score_list))
+
+                self._update_pbar(len(result_score_list))
+
+            else:
+                chunk_size = min(cpu_count()*2, self.n_jobs*4)
+
+                for chunk_idx in range(0, len(sklearn_pipeline_list), chunk_size):
+                    self._stop_by_max_time_mins()
+
+                    parallel = Parallel(n_jobs=self.n_jobs, verbose=0, pre_dispatch='2*n_jobs')
+                    tmp_result_scores = parallel(
+                        delayed(partial_wrapped_cross_val_score)(sklearn_pipeline=sklearn_pipeline)
+                        for sklearn_pipeline in sklearn_pipeline_list[chunk_idx:chunk_idx + chunk_size])
+                    # update pbar
+                    for val in tmp_result_scores:
+                        result_score_list = self._update_val(val, result_score_list)

         self._update_evaluated_individuals_(result_score_list, eval_individuals_str, operator_counts, stats_dicts)
diff --git a/tpot/gp_deap.py b/tpot/gp_deap.py
index 5a805bc7..31dd3289 100644
--- a/tpot/gp_deap.py
+++ b/tpot/gp_deap.py
@@ -23,6 +23,7 @@
 """
+import dask
 import numpy as np
 from deap import tools, gp
 from inspect import isclass
@@ -395,8 +396,10 @@ def mutNodeReplacement(individual, pset):

 @threading_timeoutable(default="Timeout")
 def _wrapped_cross_val_score(sklearn_pipeline, features, target,
-                             cv, scoring_function, sample_weight=None, groups=None):
+                             cv, scoring_function, sample_weight=None,
+                             groups=None, use_dask=False):
     """Fit estimator and compute scores for a given dataset split.
+
     Parameters
     ----------
     sklearn_pipeline : pipeline object implementing 'fit'
@@ -418,6 +421,8 @@ def _wrapped_cross_val_score(sklearn_pipeline, features, target,
         List of sample weights to balance (or un-balanace) the dataset target as needed
     groups: array-like {n_samples, }, optional
         Group labels for the samples used while splitting the dataset into train/test set
+    use_dask : bool, default False
+        Whether to return the cross-validation score as a lazy Dask graph instead of computing it eagerly
     """
     sample_weight_dict = set_sample_weight(sklearn_pipeline.steps, sample_weight)
@@ -427,22 +432,50 @@
     cv_iter = list(cv.split(features, target, groups))
     scorer = check_scoring(sklearn_pipeline, scoring=scoring_function)

-    try:
-        with warnings.catch_warnings():
-            warnings.simplefilter('ignore')
-            scores = [_fit_and_score(estimator=clone(sklearn_pipeline),
-                                     X=features,
-                                     y=target,
-                                     scorer=scorer,
-                                     train=train,
-                                     test=test,
-                                     verbose=0,
-                                     parameters=None,
-                                     fit_params=sample_weight_dict)
-                      for train, test in cv_iter]
-            CV_score = np.array(scores)[:, 0]
-            return np.nanmean(CV_score)
-    except TimeoutException:
-        return "Timeout"
-    except Exception as e:
-        return -float('inf')
+    if use_dask:
+        try:
+            import dask_ml.model_selection  # noqa
+            import dask  # noqa
+            from dask.delayed import Delayed
+        except ImportError:
+            msg = "'use_dask' requires the optional dask and dask-ml dependencies."
+            raise ImportError(msg)
+
+        dsk, keys, n_splits = dask_ml.model_selection._search.build_graph(
+            estimator=sklearn_pipeline,
+            cv=cv,
+            scorer=scorer,
+            candidate_params=[{}],
+            X=features,
+            y=target,
+            groups=groups,
+            fit_params=sample_weight_dict,
+            refit=False,
+            error_score=float('-inf'),
+        )
+
+        cv_results = Delayed(keys[0], dsk)
+        scores = [cv_results['split{}_test_score'.format(i)]
+                  for i in range(n_splits)]
+        CV_score = dask.delayed(np.array)(scores)[:, 0]
+        return dask.delayed(np.nanmean)(CV_score)
+    else:
+        try:
+            with warnings.catch_warnings():
+                warnings.simplefilter('ignore')
+                scores = [_fit_and_score(estimator=clone(sklearn_pipeline),
+                                         X=features,
+                                         y=target,
+                                         scorer=scorer,
+                                         train=train,
+                                         test=test,
+                                         verbose=0,
+                                         parameters=None,
+                                         fit_params=sample_weight_dict)
+                          for train, test in cv_iter]
+                CV_score = np.array(scores)[:, 0]
+                return np.nanmean(CV_score)
+        except TimeoutException:
+            return "Timeout"
+        except Exception as e:
+            return -float('inf')
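For context on the two hunks above: with ``use_dask=True``, ``_wrapped_cross_val_score`` returns a lazy ``Delayed`` object rather than a number, and ``_evaluate_individuals`` materializes all candidates in a single ``dask.compute(*result_score_list)`` call, which lets Dask evaluate graph keys shared between pipelines only once. The toy sketch below (not TPOT code) illustrates that build-then-compute pattern; ``fit_fold`` and ``lazy_cv_score`` are hypothetical stand-ins for ``_fit_and_score`` and the Dask-ML graph construction.

```python
# Toy sketch of the lazy-evaluation pattern used above: build one Delayed
# graph per candidate, then materialize them all together.
import dask
import numpy as np


@dask.delayed
def fit_fold(pipeline_id, fold):
    # Hypothetical stand-in for one _fit_and_score call; returns a score.
    return 0.8 + 0.01 * fold - 0.001 * pipeline_id


def lazy_cv_score(pipeline_id, n_splits=3):
    # Mirrors _wrapped_cross_val_score(use_dask=True): nothing is fitted
    # here; we only assemble a graph whose final node is the mean score.
    scores = [fit_fold(pipeline_id, fold) for fold in range(n_splits)]
    return dask.delayed(np.nanmean)(scores)


# Mirrors _evaluate_individuals: collect one lazy score per pipeline...
graphs = [lazy_cv_score(pid) for pid in range(5)]

# ...then a single compute() walks all graphs at once, so any keys they
# share are evaluated a single time instead of once per pipeline.
result_score_list = list(dask.compute(*graphs))
print(result_score_list)
```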