[WIP] Use dask.delayed within fit #730

Merged
merged 34 commits into from
Aug 29, 2018
8cb1fae
[WIP] Use dask.delayed within fit
mrocklin Jul 15, 2018
2e1b373
add dask[delayed] to ci configurations
mrocklin Jul 16, 2018
51cf0cd
Fixup fail side
TomAugspurger Jul 25, 2018
d1d10a0
Fixed return shape
TomAugspurger Aug 6, 2018
7e7e68f
Reuse dask-ml
TomAugspurger Aug 7, 2018
d06edd6
configurable delayed
TomAugspurger Aug 7, 2018
6d370ec
typo
TomAugspurger Aug 7, 2018
8808c77
Merge remote-tracking branch 'upstream/development' into mrocklin-dask
TomAugspurger Aug 7, 2018
5d21024
toggle approach
TomAugspurger Aug 8, 2018
439ae5c
wip
TomAugspurger Aug 8, 2018
d9aca85
chunking
TomAugspurger Aug 8, 2018
b919bed
assign
TomAugspurger Aug 8, 2018
54011ff
push debugging code
TomAugspurger Aug 9, 2018
2050b2d
some cleanup, tests
TomAugspurger Aug 15, 2018
2577d08
some cleanup, tests
TomAugspurger Aug 15, 2018
36b2d23
Docs
TomAugspurger Aug 15, 2018
3ba6082
dependencies
TomAugspurger Aug 15, 2018
c378519
Bump dask-ml version
TomAugspurger Aug 15, 2018
2eb71dd
debugging CI
TomAugspurger Aug 16, 2018
4ca3b95
debugging CI
TomAugspurger Aug 16, 2018
ef325b4
debugging
TomAugspurger Aug 16, 2018
a3102ac
Handle training errors
TomAugspurger Aug 21, 2018
3144977
Try dask-ml master
TomAugspurger Aug 21, 2018
a80888f
test
TomAugspurger Aug 21, 2018
6a85646
Trigger CI
TomAugspurger Aug 21, 2018
22226f2
Trigger CI
TomAugspurger Aug 21, 2018
84b4474
remove pythonhashseed
TomAugspurger Aug 21, 2018
ea3a1bd
print debug
TomAugspurger Aug 22, 2018
d9b4d9a
Try single-threaded
TomAugspurger Aug 29, 2018
4342853
smaller
TomAugspurger Aug 29, 2018
d279253
Handle failure on master
TomAugspurger Aug 29, 2018
ac6b770
skip that test
TomAugspurger Aug 29, 2018
b3342fb
Install from git
TomAugspurger Aug 29, 2018
3d2fcd1
Doc / cleanup
TomAugspurger Aug 29, 2018
2 changes: 1 addition & 1 deletion .appveyor.yml
@@ -21,7 +21,7 @@ install:
- conda info -a
- conda create -q -n test-environment python=%PYTHON_VERSION% numpy scipy scikit-learn nose cython pandas
- activate test-environment
- pip install deap tqdm update_checker pypiwin32 stopit
- pip install deap tqdm update_checker pypiwin32 stopit dask[delayed] git+https://github.com/dask/dask-ml


test_script:
2 changes: 2 additions & 0 deletions ci/.travis_install.sh
@@ -54,6 +54,8 @@ pip install update_checker
pip install tqdm
pip install stopit
pip install xgboost
pip install dask[delayed]
pip install git+https://github.com/dask/dask-ml # TODO: Change to >=0.9 when released

if [[ "$COVERAGE" == "true" ]]; then
pip install coverage coveralls
41 changes: 41 additions & 0 deletions docs_sources/using.md
@@ -555,6 +555,47 @@ rmtree(cachedir)

**Note: TPOT does NOT clean up memory caches if users set a custom directory path or Memory object. We recommend that you clean up the memory caches when you no longer need them.**

# Parallel Training

Internally, TPOT uses [joblib](http://joblib.readthedocs.io/) to fit estimators in parallel.
This is the same parallelization framework used by scikit-learn.

When you specify ``n_jobs``, TPOT will use ``n_jobs`` processes to fit models in parallel.
For large problems, you can distribute the work on a [Dask](http://dask.pydata.org/en/latest/) cluster.
There are two ways to achieve this.

First, you can specify the ``use_dask`` keyword when you create the TPOT estimator.

```python
estimator = TPOTClassifier(n_jobs=-1, use_dask=True)
```

This will use all the workers on your cluster to do the training, and use [Dask-ML's pipeline rewriting](https://dask-ml.readthedocs.io/en/latest/hyper-parameter-search.html#avoid-repeated-work) to avoid re-fitting estimators multiple times on the same set of data.
It will also provide fine-grained diagnostics in the [distributed scheduler UI](https://distributed.readthedocs.io/en/latest/web.html).

Alternatively, Dask implements a joblib backend.
You can instruct TPOT to use the distributed backend during training by specifying a ``joblib.parallel_backend``:

```python
from sklearn.externals import joblib
import distributed.joblib
from dask.distributed import Client

# connect to the cluster
client = Client('scheduler-address')

# create the estimator normally
estimator = TPOTClassifier(n_jobs=-1)

# perform the fit in this context manager
with joblib.parallel_backend("dask"):
    estimator.fit(X, y)
```

See [dask's distributed joblib integration](https://distributed.readthedocs.io/en/latest/joblib.html) for more.

We recommend using the `use_dask` keyword.

# Crash/freeze issue with n_jobs > 1 under OSX or Linux

TPOT supports parallel computing for speeding up the optimization process, but it may crash/freeze with n_jobs > 1 under OSX or Linux [as scikit-learn does](http://scikit-learn.org/stable/faq.html#why-do-i-sometime-get-a-crash-freeze-with-n-jobs-1-under-osx-or-linux), especially with large datasets.
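
The joblib-backend route from the Parallel Training section can be wrapped in a small helper. This is a minimal sketch, not TPOT API: `fit_with_dask_backend` and its `scheduler_address` parameter are hypothetical names. It assumes a recent standalone `joblib` with `distributed` installed (where the `"dask"` backend is available without the `sklearn.externals` and `distributed.joblib` imports used in the documentation snippet) and a TPOT release that parallelizes through that same `joblib`.

```python
def fit_with_dask_backend(estimator, X, y, scheduler_address=None):
    """Fit `estimator` while joblib work is routed to a Dask cluster.

    Hypothetical helper for illustration only.
    """
    # Imports are local so that defining the helper needs no optional packages.
    import joblib
    from dask.distributed import Client

    # With address None, Client starts a local cluster;
    # pass a scheduler address to join an existing one.
    client = Client(scheduler_address)
    try:
        # Every joblib.Parallel call made inside this block uses Dask workers.
        with joblib.parallel_backend("dask"):
            estimator.fit(X, y)
    finally:
        client.close()
    return estimator
```

Closing the client in `finally` keeps a failed fit from leaving a dangling connection; whether that is desirable depends on whether the cluster is shared with other work.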
5 changes: 4 additions & 1 deletion setup.py
@@ -46,7 +46,10 @@ def calculate_version():
extras_require={
'xgboost': ['xgboost==0.6a2'],
'skrebate': ['skrebate>=0.3.4'],
'mdr': ['scikit-mdr>=0.4.4']
'mdr': ['scikit-mdr>=0.4.4'],
'dask': ['dask>=0.18.2',
'distributed>=1.22.1',
'dask-ml>=0.8.0'],
},
classifiers=[
'Intended Audience :: Science/Research',
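
With a `dask` extra declared in `extras_require`, the optional stack would typically be installed via pip's extras syntax, for example `pip install tpot[dask]`. The sketch below checks at runtime which of those optional modules are importable; `missing_optional_modules` is a hypothetical helper, not something this PR adds.

```python
import importlib.util


def missing_optional_modules(modules=("dask", "distributed", "dask_ml")):
    """Return the names from `modules` that cannot be found for import."""
    # find_spec returns None for a top-level module that is not installed.
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_optional_modules()
    if missing:
        print("use_dask=True would need:", ", ".join(missing))
```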
10 changes: 10 additions & 0 deletions tests/driver_tests.py
@@ -27,19 +27,23 @@
import sys
from os import remove, path
from contextlib import contextmanager
from distutils.version import LooseVersion
try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO

import nose
import numpy as np
import pandas as pd
import sklearn

from tpot.driver import positive_integer, float_range, _get_arg_parser, _print_args, _read_data_file, load_scoring_function, tpot_driver
from nose.tools import assert_raises, assert_equal, assert_in
from unittest import TestCase



@contextmanager
def captured_output():
    new_out, new_err = StringIO(), StringIO()
@@ -169,6 +173,12 @@ def test_driver_4():

def test_driver_5():
    """Assert that the tpot_driver() in TPOT driver outputs normal result with exported python file and verbosity = 0."""

    # Catch FutureWarning https://github.com/scikit-learn/scikit-learn/issues/11785
    if (np.__version__ >= LooseVersion("1.15.0") and
            sklearn.__version__ <= LooseVersion("0.20.0")):
        raise nose.SkipTest("Warning raised by scikit-learn")

    args_list = [
        'tests/tests.csv',
        '-is', ',',
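
The skip condition in `test_driver_5` relies on `distutils.version.LooseVersion`, and `distutils` was deprecated by PEP 632 and removed in Python 3.12. A stdlib-only sketch of the same gate follows; `version_tuple` and `should_skip` are hypothetical names, and the parsing is deliberately crude (leading integers only), so it is an approximation rather than a drop-in replacement.

```python
import re


def version_tuple(version):
    """Turn '1.15.0' or '0.20.dev0' into a tuple of its leading integers."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component, e.g. 'dev0'
        parts.append(int(match.group()))
    return tuple(parts)


def should_skip(numpy_version, sklearn_version):
    """Mirror the test's gate: NumPy >= 1.15.0 together with scikit-learn <= 0.20.0."""
    return (version_tuple(numpy_version) >= (1, 15, 0)
            and version_tuple(sklearn_version) <= (0, 20, 0))
```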
78 changes: 78 additions & 0 deletions tests/test_dask_based.py
@@ -0,0 +1,78 @@
"""Tests that ensure the dask-based fit matches.

https://github.com/DEAP/deap/issues/75
"""
import unittest

import nose
from sklearn.datasets import make_classification

from tpot import TPOTClassifier

try:
    import dask  # noqa
    import dask_ml  # noqa
except ImportError:
    raise nose.SkipTest()


class TestDaskMatches(unittest.TestCase):

    def test_dask_matches(self):
        with dask.config.set(scheduler='single-threaded'):
            for n_jobs in [-1]:
                X, y = make_classification(random_state=0)
                a = TPOTClassifier(
                    generations=2,
                    population_size=5,
                    cv=3,
                    random_state=0,
                    n_jobs=n_jobs,
                    use_dask=False,
                )
                b = TPOTClassifier(
                    generations=2,
                    population_size=5,
                    cv=3,
                    random_state=0,
                    n_jobs=n_jobs,
                    use_dask=True,
                )
                b.fit(X, y)
                a.fit(X, y)

                self.assertEqual(a.score(X, y), b.score(X, y))
                self.assertEqual(a.pareto_front_fitted_pipelines_.keys(),
                                 b.pareto_front_fitted_pipelines_.keys())
                self.assertEqual(a.evaluated_individuals_,
                                 b.evaluated_individuals_)

    def test_handles_errors(self):
        X, y = make_classification(n_samples=5)

        tpot_config = {
            'sklearn.neighbors.KNeighborsClassifier': {
                'n_neighbors': [1, 100],
            }
        }

        a = TPOTClassifier(
            generations=2,
            population_size=5,
            cv=3,
            random_state=0,
            n_jobs=-1,
            config_dict=tpot_config,
            use_dask=False,
        )
        b = TPOTClassifier(
            generations=2,
            population_size=5,
            cv=3,
            random_state=0,
            n_jobs=-1,
            config_dict=tpot_config,
            use_dask=True,
        )
        a.fit(X, y)
        b.fit(X, y)
54 changes: 42 additions & 12 deletions tpot/base.py
@@ -124,7 +124,8 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
random_state=None, config_dict=None,
warm_start=False, memory=None,
periodic_checkpoint_folder=None, early_stop=None,
verbosity=0, disable_update_check=False):
verbosity=0, disable_update_check=False,
use_dask=False):
"""Set up the genetic programming algorithm for pipeline optimization.

Parameters
@@ -252,6 +253,14 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
A setting of 2 or higher will add a progress bar during the optimization procedure.
disable_update_check: bool, optional (default: False)
Flag indicating whether the TPOT version checker should be disabled.
use_dask : bool, default False
Whether to use Dask-ML's pipeline optimizations. This avoids re-fitting
the same estimator on the same split of data multiple times. It
will also provide more detailed diagnostics when using Dask's
distributed scheduler.

See `avoid repeated work <https://dask-ml.readthedocs.io/en/latest/hyper-parameter-search.html#avoid-repeated-work>`__
for more.

Returns
-------
@@ -286,6 +295,7 @@ def __init__(self, generations=100, population_size=100, offspring_size=None,
self._last_optimized_pareto_front_n_gens = 0
self.memory = memory
self._memory = None # initial Memory setting for sklearn pipeline
self.use_dask = use_dask

# dont save periodic pipelines more often than this
self._output_best_pipeline_period_seconds = 30
@@ -1190,28 +1200,48 @@ def _evaluate_individuals(self, individuals, features, target, sample_weight=Non
            scoring_function=self.scoring_function,
            sample_weight=sample_weight,
            groups=groups,
            timeout=self.max_eval_time_seconds
            timeout=self.max_eval_time_seconds,
            use_dask=self.use_dask,
        )

        result_score_list = []
        # Don't use parallelization if n_jobs==1
        if self.n_jobs == 1:
        if self.n_jobs == 1 and not self.use_dask:
            for sklearn_pipeline in sklearn_pipeline_list:
                self._stop_by_max_time_mins()
                val = partial_wrapped_cross_val_score(sklearn_pipeline=sklearn_pipeline)
                result_score_list = self._update_val(val, result_score_list)
        else:
            # chunk size for pbar update
            # chunk size is min of cpu_count * 2 and n_jobs * 4
            chunk_size = min(cpu_count()*2, self.n_jobs*4)
            for chunk_idx in range(0, len(sklearn_pipeline_list), chunk_size):
                self._stop_by_max_time_mins()
                parallel = Parallel(n_jobs=self.n_jobs, verbose=0, pre_dispatch='2*n_jobs')
                tmp_result_scores = parallel(delayed(partial_wrapped_cross_val_score)(sklearn_pipeline=sklearn_pipeline)
                                             for sklearn_pipeline in sklearn_pipeline_list[chunk_idx:chunk_idx + chunk_size])
                # update pbar
                for val in tmp_result_scores:
                    result_score_list = self._update_val(val, result_score_list)
            if self.use_dask:
                import dask

                result_score_list = [
                    partial_wrapped_cross_val_score(sklearn_pipeline=sklearn_pipeline)
                    for sklearn_pipeline in sklearn_pipeline_list
                ]

                self.dask_graphs_ = result_score_list
                with warnings.catch_warnings():
                    warnings.simplefilter('ignore')
                    result_score_list = list(dask.compute(*result_score_list))

                self._update_pbar(len(result_score_list))

            else:
                chunk_size = min(cpu_count()*2, self.n_jobs*4)

                for chunk_idx in range(0, len(sklearn_pipeline_list), chunk_size):
                    self._stop_by_max_time_mins()

                    parallel = Parallel(n_jobs=self.n_jobs, verbose=0, pre_dispatch='2*n_jobs')
                    tmp_result_scores = parallel(
                        delayed(partial_wrapped_cross_val_score)(sklearn_pipeline=sklearn_pipeline)
                        for sklearn_pipeline in sklearn_pipeline_list[chunk_idx:chunk_idx + chunk_size])
                    # update pbar
                    for val in tmp_result_scores:
                        result_score_list = self._update_val(val, result_score_list)

        self._update_evaluated_individuals_(result_score_list, eval_individuals_str, operator_counts, stats_dicts)
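
The joblib branch in this hunk dispatches pipelines in chunks, which lets `_stop_by_max_time_mins()` run and the progress bar advance between chunks, whereas the Dask branch hands the whole list to a single `dask.compute` call. A stdlib-only sketch of the chunked iteration follows; `chunk_size_for` and `iter_chunks` are illustrative names, not TPOT functions, and the sketch assumes `n_jobs` has already been resolved to a positive integer.

```python
def chunk_size_for(n_jobs, n_cpus):
    """Chunk size rule from the hunk: min(cpu_count * 2, n_jobs * 4)."""
    # Assumes n_jobs > 0; a negative value would make range() below yield nothing.
    return min(n_cpus * 2, n_jobs * 4)


def iter_chunks(items, chunk_size):
    """Yield consecutive slices of `items`, each at most `chunk_size` long."""
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]


if __name__ == "__main__":
    pipelines = ["pipeline_{}".format(i) for i in range(10)]
    for chunk in iter_chunks(pipelines, chunk_size_for(n_jobs=1, n_cpus=4)):
        # A time-limit check or progress update would go here, once per chunk.
        print(len(chunk))
```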

73 changes: 53 additions & 20 deletions tpot/gp_deap.py
@@ -23,6 +23,7 @@

"""

import dask
import numpy as np
from deap import tools, gp
from inspect import isclass
@@ -395,8 +396,10 @@ def mutNodeReplacement(individual, pset):

@threading_timeoutable(default="Timeout")
def _wrapped_cross_val_score(sklearn_pipeline, features, target,
cv, scoring_function, sample_weight=None, groups=None):
cv, scoring_function, sample_weight=None,
groups=None, use_dask=False):
"""Fit estimator and compute scores for a given dataset split.

Parameters
----------
sklearn_pipeline : pipeline object implementing 'fit'
@@ -418,6 +421,8 @@ def _wrapped_cross_val_score(sklearn_pipeline, features, target,
List of sample weights to balance (or un-balance) the dataset target as needed
groups: array-like {n_samples, }, optional
Group labels for the samples used while splitting the dataset into train/test set
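use_dask : bool, default False
Whether to build the cross-validation as a lazy Dask graph (via Dask-ML)
and return a delayed score instead of fitting eagerly.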
"""
sample_weight_dict = set_sample_weight(sklearn_pipeline.steps, sample_weight)

@@ -427,22 +432,50 @@ def _wrapped_cross_val_score(sklearn_pipeline, features, target,
    cv_iter = list(cv.split(features, target, groups))
    scorer = check_scoring(sklearn_pipeline, scoring=scoring_function)

    try:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            scores = [_fit_and_score(estimator=clone(sklearn_pipeline),
                                     X=features,
                                     y=target,
                                     scorer=scorer,
                                     train=train,
                                     test=test,
                                     verbose=0,
                                     parameters=None,
                                     fit_params=sample_weight_dict)
                      for train, test in cv_iter]
            CV_score = np.array(scores)[:, 0]
            return np.nanmean(CV_score)
    except TimeoutException:
        return "Timeout"
    except Exception as e:
        return -float('inf')
    if use_dask:
        try:
            import dask_ml.model_selection  # noqa
            import dask  # noqa
            from dask.delayed import Delayed
        except ImportError:
            msg = "'use_dask' requires the optional dask and dask-ml dependencies."
            raise ImportError(msg)

        dsk, keys, n_splits = dask_ml.model_selection._search.build_graph(
            estimator=sklearn_pipeline,
            cv=cv,
            scorer=scorer,
            candidate_params=[{}],
            X=features,
            y=target,
            groups=groups,
            fit_params=sample_weight_dict,
            refit=False,
            error_score=float('-inf'),
        )

        cv_results = Delayed(keys[0], dsk)
        scores = [cv_results['split{}_test_score'.format(i)]
                  for i in range(n_splits)]
        CV_score = dask.delayed(np.array)(scores)[:, 0]
        return dask.delayed(np.nanmean)(CV_score)
    else:
        try:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                scores = [_fit_and_score(estimator=clone(sklearn_pipeline),
                                         X=features,
                                         y=target,
                                         scorer=scorer,
                                         train=train,
                                         test=test,
                                         verbose=0,
                                         parameters=None,
                                         fit_params=sample_weight_dict)
                          for train, test in cv_iter]
                CV_score = np.array(scores)[:, 0]
                return np.nanmean(CV_score)
        except TimeoutException:
            return "Timeout"
        except Exception as e:
            return -float('inf')
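
The Dask branch passes `error_score=float('-inf')` and then takes a NaN-ignoring mean over the per-split scores. Assuming a failed split is recorded as that error score, one failure drives the whole mean to `-inf`, which lines up with the eager branch returning `-float('inf')` on an exception. A stdlib-only sketch of that arithmetic follows; `nan_mean` is a stand-in written for this example, not `np.nanmean` itself.

```python
import math


def nan_mean(values):
    """Mean of `values` ignoring NaN entries; NaN if nothing is left."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else float("nan")


# A NaN split score is skipped, so the mean covers the remaining splits.
ok_scores = [0.9, float("nan"), 0.7]

# A split recorded as -inf is not skipped and dominates the mean.
failed_scores = [0.9, float("-inf"), 0.8]

if __name__ == "__main__":
    print(nan_mean(ok_scores))
    print(nan_mean(failed_scores))
```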