[dask] factor dask-ml out of tests (fixes microsoft#3796)
jameslamb committed Jan 25, 2021
1 parent 36322ce commit a1dd3d8
Showing 2 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .ci/test.sh
@@ -87,7 +87,7 @@ if [[ $TASK == "swig" ]]; then
exit 0
fi

-conda install -q -y -n $CONDA_ENV dask dask-ml distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy
+conda install -q -y -n $CONDA_ENV dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy

# graphviz must come from conda-forge to avoid this on some linux distros:
# https://github.com/conda-forge/graphviz-feedstock/issues/18
17 changes: 13 additions & 4 deletions tests/python_package_test/test_dask.py
@@ -17,7 +17,6 @@
from scipy.stats import spearmanr
import scipy.sparse
from dask.array.utils import assert_eq
-from dask_ml.metrics import accuracy_score, r2_score
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
from sklearn.datasets import make_blobs, make_regression
from sklearn.utils import check_random_state
@@ -124,6 +123,16 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size
return X, y, weights, dX, dy, dw


+def _r2_score(dy_true, dy_pred):
+    numerator = ((dy_true - dy_pred) ** 2).sum(axis=0, dtype="f8")
+    denominator = ((dy_true - dy_pred.mean(axis=0)) ** 2).sum(axis=0, dtype="f8")
+    return (1 - numerator / denominator).compute()
+
+
+def _accuracy_score(dy_true, dy_pred):
+    return da.average(dy_true == dy_pred).compute()
+
+
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier(output, centers, client, listen_port):
@@ -145,7 +154,7 @@ def test_classifier(output, centers, client, listen_port):
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_classifier.predict(dX)
p1_proba = dask_classifier.predict_proba(dX).compute()
-    s1 = accuracy_score(dy, p1)
+    s1 = _accuracy_score(dy, p1)
p1 = p1.compute()

local_classifier = lightgbm.LGBMClassifier(**params)
@@ -289,7 +298,7 @@ def test_regressor(output, client, listen_port):
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX)
if output != 'dataframe':
-        s1 = r2_score(dy, p1)
+        s1 = _r2_score(dy, p1)
p1 = p1.compute()

local_regressor = lightgbm.LGBMRegressor(**params)
@@ -391,7 +400,7 @@ def test_regressor_local_predict(client, listen_port):
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_regressor.predict(dX)
p2 = dask_regressor.to_local().predict(X)
-    s1 = r2_score(dy, p1)
+    s1 = _r2_score(dy, p1)
p1 = p1.compute()
s2 = dask_regressor.to_local().score(X, y)

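Note on the new helpers: _accuracy_score and _r2_score replace the accuracy_score and r2_score imports that previously came from dask_ml.metrics, keeping the metric computation on dask arrays. The standalone sketch below is illustrative only and not part of the commit; the array sizes, chunking, and the scikit-learn comparison are assumptions made for demonstration.

# Illustrative, standalone sketch (not part of the commit): exercising the two
# helpers introduced above on small dask arrays.
import dask.array as da
import numpy as np
from sklearn.metrics import accuracy_score as sk_accuracy_score


def _r2_score(dy_true, dy_pred):
    # copied from the diff above: 1 - (sum of squared residuals) /
    # (sum of squared deviations of the truth from the mean of the predictions)
    numerator = ((dy_true - dy_pred) ** 2).sum(axis=0, dtype="f8")
    denominator = ((dy_true - dy_pred.mean(axis=0)) ** 2).sum(axis=0, dtype="f8")
    return (1 - numerator / denominator).compute()


def _accuracy_score(dy_true, dy_pred):
    # copied from the diff above: fraction of elementwise matches
    return da.average(dy_true == dy_pred).compute()


rng = np.random.RandomState(42)

# classification labels: _accuracy_score agrees with sklearn's accuracy_score
y_true = rng.randint(0, 2, size=1000)
y_pred = rng.randint(0, 2, size=1000)
acc = _accuracy_score(da.from_array(y_true, chunks=100), da.from_array(y_pred, chunks=100))
assert np.isclose(acc, sk_accuracy_score(y_true, y_pred))

# regression targets: the reductions stay lazy until the final .compute()
t = rng.rand(1000)
p = t + 0.1 * rng.rand(1000)
score = _r2_score(da.from_array(t, chunks=100), da.from_array(p, chunks=100))
assert score <= 1.0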
