Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ci] [python] reduce unnecessary data loading in tests #3486

Merged
merged 16 commits into from
Oct 29, 2020
Merged
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ htmlcov/
.coverage.*
.cache
nosetests.xml
prof/
*.prof
coverage.xml
*,cover
.hypothesis/
Expand Down
Empty file.
4 changes: 3 additions & 1 deletion tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import numpy as np

from scipy import sparse
from sklearn.datasets import load_breast_cancer, dump_svmlight_file, load_svmlight_file
from sklearn.datasets import dump_svmlight_file, load_svmlight_file
from sklearn.model_selection import train_test_split

from .utils import load_breast_cancer


class TestBasic(unittest.TestCase):

Expand Down
6 changes: 4 additions & 2 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import lightgbm as lgb
import numpy as np
from scipy.sparse import csr_matrix, isspmatrix_csr, isspmatrix_csc
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
load_iris, load_svmlight_file, make_multilabel_classification)
from sklearn.datasets import load_svmlight_file, make_multilabel_classification
from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split, TimeSeriesSplit, GroupKFold

Expand All @@ -20,6 +19,8 @@
except ImportError:
import pickle

from .utils import load_boston, load_breast_cancer, load_digits, load_iris


decreasing_generator = itertools.count(0, -1)

Expand Down Expand Up @@ -2524,6 +2525,7 @@ def test_average_precision_metric(self):
sklearn_ap = average_precision_score(y, pred)
self.assertAlmostEqual(ap, sklearn_ap)
# test that average precision is 1 where model predicts perfectly
y = y.copy()
y[:] = 1
lgb_X = lgb.Dataset(X, label=y)
lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], evals_result=res)
Expand Down
3 changes: 2 additions & 1 deletion tests/python_package_test/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import lightgbm as lgb
from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

if MATPLOTLIB_INSTALLED:
Expand All @@ -12,6 +11,8 @@
if GRAPHVIZ_INSTALLED:
import graphviz

from .utils import load_breast_cancer


class TestBasic(unittest.TestCase):

Expand Down
6 changes: 3 additions & 3 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import numpy as np
from sklearn import __version__ as sk_version
from sklearn.base import clone
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
load_iris, load_linnerud, load_svmlight_file,
make_multilabel_classification)
from sklearn.datasets import load_svmlight_file, make_multilabel_classification
from sklearn.exceptions import SkipTestWarning
from sklearn.metrics import log_loss, mean_squared_error
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split
Expand All @@ -22,6 +20,8 @@
check_parameters_default_constructible)
from sklearn.utils.validation import check_is_fitted

from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud


decreasing_generator = itertools.count(0, -1)

Expand Down
45 changes: 45 additions & 0 deletions tests/python_package_test/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# coding: utf-8
import sklearn.datasets

try:
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
from functools import lru_cache
except ImportError:
import warnings
warnings.warn("Could not import functools.lru_cache", RuntimeWarning)

def lru_cache(maxsize=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please memoize this too

self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(*load_breast_cancer(return_X_y=True),

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what would be the purpose? That's the only call to load_breast_cancer() in that module. So I think the caching would add a tiny bit of overhead for no benefit.

Copy link
Collaborator

@StrikerRUS StrikerRUS Oct 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it differ with other calls you've memoized? 5 calls of load_breast_cancer() in test_plotting.py is even more than 3 calls of the same function in test_basic.py, for example.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is only 1 call in test_plotting.py

git grep load_breast_cancer tests/python_package_test/test_plotting.py

image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right, but the method in which this call is performed (setUp) is called before each test. So, actually we have 5 calls.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OOOOOOOOOOO haha ok. I haven't used unittest.TestCase in a while, I forgot which one was a "run before every test" setup and which one was a "run exactly once, before any tests" one.

Ok yes I'll update this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added in dfb0fd3

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Actually I think we should refactor this to "run exactly once, before any tests" (setUpClass()), but it is another issue, of course.

cache = {}

def _lru_wrapper(user_function):
def wrapper(*args, **kwargs):
arg_key = (args, tuple(kwargs.items()))
if arg_key not in cache:
cache[arg_key] = user_function(*args, **kwargs)
return cache[arg_key]
return wrapper
return _lru_wrapper


@lru_cache(maxsize=None)
def load_boston(**kwargs):
return sklearn.datasets.load_boston(**kwargs)


@lru_cache(maxsize=None)
def load_breast_cancer(**kwargs):
return sklearn.datasets.load_breast_cancer(**kwargs)


@lru_cache(maxsize=None)
def load_digits(**kwargs):
return sklearn.datasets.load_digits(**kwargs)


@lru_cache(maxsize=None)
def load_iris(**kwargs):
return sklearn.datasets.load_iris(**kwargs)


@lru_cache(maxsize=None)
def load_linnerud(**kwargs):
return sklearn.datasets.load_linnerud(**kwargs)