From 5871b008ef832c4a0d312801c1b9c191f60f491d Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 29 Jan 2021 16:42:56 -0600 Subject: [PATCH 01/18] starting on Dask client --- python-package/lightgbm/sklearn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 6a48f2c7b7d5..632082fe9eea 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -291,6 +291,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, if not SKLEARN_INSTALLED: raise LightGBMError('scikit-learn is required for lightgbm.sklearn') + # Dask estimators inherit from this and may pass an argument "client" + self.__client = kwargs.pop("client", None) + self.boosting_type = boosting_type self.objective = objective self.num_leaves = num_leaves From 7b95654e8e835274e6271d8152f52903247a7026 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 29 Jan 2021 17:27:36 -0600 Subject: [PATCH 02/18] more docs stuff --- python-package/lightgbm/dask.py | 59 +++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index c5f4049b0d5f..5d809f89a821 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -434,6 +434,18 @@ def _predict( class _DaskLGBMModel: + @property + def client(self) -> Client: + """Dask client + + This property can be passed in the constructor or directly assigned + like ``model.client = client``. + """ + if self.__client is None: + return default_client() + else: + return self.__client + def _fit( self, model_factory: Type[LGBMModel], @@ -446,13 +458,11 @@ def _fit( ) -> "_DaskLGBMModel": if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') - if client is None: - client = default_client() params = self.get_params(True) model = _train( - client=client, + client=self.client, data=X, label=y, params=params, @@ -489,7 +499,6 @@ def fit( X: _DaskMatrixLike, y: _DaskCollection, sample_weight: Optional[_DaskCollection] = None, - client: Optional[Client] = None, **kwargs: Any ) -> "DaskLGBMClassifier": """Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" @@ -498,17 +507,10 @@ def fit( X=X, y=y, sample_weight=sample_weight, - client=client, + client=self.client, **kwargs ) - _base_doc = LGBMClassifier.fit.__doc__ - _before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :') - fit.__doc__ = (_before_init_score - + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client.\n' - + ' ' * 8 + _init_score + _after_init_score) - def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" return _predict( @@ -542,6 +544,16 @@ def to_local(self) -> LGBMClassifier: return self._to_local(LGBMClassifier) +_base_doc = LGBMClassifier.__init__.__doc__ +_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') +DaskLGBMClassifier.__init__.__doc__ = ( + _before_kwargs + + 'client : dask.distributed.Client or None, optional (default=None)\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. 
This client will not be saved if the model object is pickled.\n' + + ' ' * 8 + _kwargs + _after_kwargs +) + + class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRegressor.""" @@ -591,6 +603,16 @@ def to_local(self) -> LGBMRegressor: return self._to_local(LGBMRegressor) +_base_doc = LGBMRegressor.__init__.__doc__ +_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') +DaskLGBMRegressor.__init__.__doc__ = ( + _before_kwargs + + 'client : dask.distributed.Client or None, optional (default=None)\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 8 + _kwargs + _after_kwargs +) + + class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRanker.""" @@ -601,7 +623,6 @@ def fit( sample_weight: Optional[_DaskCollection] = None, init_score: Optional[_DaskCollection] = None, group: Optional[_DaskCollection] = None, - client: Optional[Client] = None, **kwargs: Any ) -> "DaskLGBMRanker": """Docstring is inherited from the lightgbm.LGBMRanker.fit.""" @@ -614,7 +635,7 @@ def fit( y=y, sample_weight=sample_weight, group=group, - client=client, + client=self.client, **kwargs ) @@ -640,3 +661,13 @@ def to_local(self) -> LGBMRanker: Local underlying model. """ return self._to_local(LGBMRanker) + + +_base_doc = LGBMRanker.__init__.__doc__ +_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') +DaskLGBMRanker.__init__.__doc__ = ( + _before_kwargs + + 'client : dask.distributed.Client or None, optional (default=None)\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 8 + _kwargs + _after_kwargs +) From e9eaeb20e2d034768cb4ab939e640e8f320f59a5 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sat, 30 Jan 2021 23:10:43 -0600 Subject: [PATCH 03/18] fix pickling --- python-package/lightgbm/dask.py | 16 +- python-package/lightgbm/sklearn.py | 9 +- tests/python_package_test/test_dask.py | 214 +++++++++++++++++++++++-- 3 files changed, 225 insertions(+), 14 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 5d809f89a821..6f0e5e1e13b9 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -434,6 +434,9 @@ def _predict( class _DaskLGBMModel: + # self._client is set in the constructor of lightgbm.sklearn.LGBMModel + _client = None + @property def client(self) -> Client: """Dask client @@ -441,10 +444,14 @@ def client(self) -> Client: This property can be passed in the constructor or directly assigned like ``model.client = client``. 
""" - if self.__client is None: + if self._client is None: return default_client() else: - return self.__client + return self._client + + @client.setter + def client(self, client: Client) -> None: + self._client = client def _fit( self, @@ -472,9 +479,9 @@ def _fit( **kwargs ) + # at this point, self._client is still set self.set_params(**model.get_params()) self._copy_extra_params(model, self) - return self def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel: @@ -488,7 +495,8 @@ def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union[" attributes = source.__dict__ extra_param_names = set(attributes.keys()).difference(params.keys()) for name in extra_param_names: - setattr(dest, name, attributes[name]) + if name != "_client": + setattr(dest, name, attributes[name]) class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 632082fe9eea..49dc121f6de5 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -292,7 +292,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, raise LightGBMError('scikit-learn is required for lightgbm.sklearn') # Dask estimators inherit from this and may pass an argument "client" - self.__client = kwargs.pop("client", None) + self._client = kwargs.pop("client", None) self.boosting_type = boosting_type self.objective = objective @@ -328,6 +328,13 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, self._n_classes = None self.set_params(**kwargs) + def __getstate__(self): + """Remove un-picklable attributes before serialization""" + client = self.__dict__.pop("_client", None) + out = copy.deepcopy(self.__dict__) + self._client = client + return out + def _more_tags(self): return { 'allow_nan': True, diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 5db6c57006bb..98253509add5 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -1,6 +1,8 @@ # coding: utf-8 """Tests for lightgbm.dask module""" +import joblib +import pickle import socket from itertools import groupby from os import getenv @@ -13,6 +15,7 @@ if not lgb.compat.DASK_INSTALLED: pytest.skip('Dask is not installed', allow_module_level=True) +import cloudpickle import dask.array as da import dask.dataframe as dd import numpy as np @@ -137,6 +140,32 @@ def _accuracy_score(dy_true, dy_pred): return da.average(dy_true == dy_pred).compute() +def _pickle(obj, filepath, serializer): + if serializer == 'pickle': + with open(filepath, 'wb') as f: + pickle.dump(obj, f) + elif serializer == 'joblib': + joblib.dump(obj, filepath) + elif serializer == 'cloudpickle': + with open(filepath, 'wb') as f: + cloudpickle.dump(obj, f) + else: + raise ValueError(f'unrecognized serializer type: {serializer}') + + +def _unpickle(filepath, serializer): + if serializer == 'pickle': + with open(filepath, 'rb') as f: + return pickle.load(f) + elif serializer == 'joblib': + return joblib.load(filepath) + elif serializer == 'cloudpickle': + with open(filepath, 'rb') as f: + return cloudpickle.load(f) + else: + raise ValueError(f'unrecognized serializer type: {serializer}') + + @pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('centers', data_centers) def test_classifier(output, centers, client, listen_port): @@ -151,11 +180,12 @@ def test_classifier(output, centers, client, listen_port): "num_leaves": 10 } dask_classifier 
= lgb.DaskLGBMClassifier( + client=client, time_out=5, local_listen_port=listen_port, **params ) - dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client) + dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw) p1 = dask_classifier.predict(dX) p1_proba = dask_classifier.predict_proba(dX).compute() p1_local = dask_classifier.to_local().predict(X) @@ -193,12 +223,13 @@ def test_classifier_pred_contrib(output, centers, client, listen_port): "num_leaves": 10 } dask_classifier = lgb.DaskLGBMClassifier( + client=client, time_out=5, local_listen_port=listen_port, tree_learner='data', **params ) - dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client) + dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw) preds_with_contrib = dask_classifier.predict(dX, pred_contrib=True).compute() local_classifier = lgb.LGBMClassifier(**params) @@ -241,6 +272,7 @@ def test_training_does_not_fail_on_port_conflicts(client): s.bind(('127.0.0.1', 12400)) dask_classifier = lgb.DaskLGBMClassifier( + client=client, time_out=5, local_listen_port=12400, n_estimators=5, @@ -251,7 +283,6 @@ def test_training_does_not_fail_on_port_conflicts(client): X=dX, y=dy, sample_weight=dw, - client=client ) assert dask_classifier.booster_ @@ -270,12 +301,13 @@ def test_regressor(output, client, listen_port): "num_leaves": 10 } dask_regressor = lgb.DaskLGBMRegressor( + client=client, time_out=5, local_listen_port=listen_port, tree='data', **params ) - dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) + dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw) p1 = dask_regressor.predict(dX) if output != 'dataframe': s1 = _r2_score(dy, p1) @@ -313,12 +345,13 @@ def test_regressor_pred_contrib(output, client, listen_port): "num_leaves": 10 } dask_regressor = lgb.DaskLGBMRegressor( + client=client, time_out=5, local_listen_port=listen_port, tree_learner='data', **params ) - dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client) + dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw) preds_with_contrib = dask_regressor.predict(dX, pred_contrib=True).compute() local_regressor = lgb.LGBMRegressor(**params) @@ -353,11 +386,12 @@ def test_regressor_quantile(output, client, listen_port, alpha): "num_leaves": 10 } dask_regressor = lgb.DaskLGBMRegressor( + client=client, local_listen_port=listen_port, tree_learner_type='data_parallel', **params ) - dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw) + dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw) p1 = dask_regressor.predict(dX).compute() q1 = np.count_nonzero(y < p1) / y.shape[0] @@ -391,12 +425,13 @@ def test_ranker(output, client, listen_port, group): "min_child_samples": 1 } dask_ranker = lgb.DaskLGBMRanker( + client=client, time_out=5, local_listen_port=listen_port, tree_learner_type='data_parallel', **params ) - dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client) + dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg) rnkvec_dask = dask_ranker.predict(dX) rnkvec_dask = rnkvec_dask.compute() rnkvec_dask_local = dask_ranker.to_local().predict(X) @@ -415,6 +450,165 @@ def test_ranker(output, client, listen_port, group): client.close(timeout=CLIENT_CLOSE_TIMEOUT) +@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) +def test_training_works_if_client_not_provided_or_set_after_construction(task, listen_port, client): + if task == 'ranking': + _, _, _, _, dX, dy, _, 
dg = _create_ranking_data( + output='array', + group=None + ) + model_factory = lgb.DaskLGBMRanker + else: + _, _, _, dX, dy, _ = _create_data( + objective=task, + output='array', + ) + dg = None + if task == 'classification': + model_factory = lgb.DaskLGBMClassifier + elif task == 'regression': + model_factory = lgb.DaskLGBMRegressor + + params = { + "time_out": 5, + "local_listen_port": listen_port, + "n_estimators": 1, + "num_leaves": 2 + } + + # fit should work if client isn't provided + dask_model = model_factory(**params) + assert dask_model._client is None + assert dask_model.client == client + + dask_model.fit(dX, dy, group=dg) + assert dask_model.fitted_ + assert dask_model._client is None + assert dask_model.client == client + + preds = dask_model.predict(dX) + assert isinstance(preds, da.Array) + assert dask_model.fitted_ + assert dask_model._client is None + assert dask_model.client == client + + local_model = dask_model.to_local() + assert local_model._client is None + with pytest.raises(AttributeError): + local_model.client + + # should be able to set client after construction + dask_model = model_factory(**params) + dask_model.client = client + assert dask_model._client == client + assert dask_model.client == client + + dask_model.fit(dX, dy, group=dg) + assert dask_model.fitted_ + assert dask_model._client == client + assert dask_model.client == client + + preds = dask_model.predict(dX) + assert isinstance(preds, da.Array) + assert dask_model.fitted_ + assert dask_model._client == client + assert dask_model.client == client + + local_model = dask_model.to_local() + assert local_model._client is None + with pytest.raises(AttributeError): + local_model.client + + client.close(timeout=CLIENT_CLOSE_TIMEOUT) + + +@pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle']) +@pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) +@pytest.mark.parametrize('set_client', [True, False]) +def test_model_is_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, client, tmp_path): + if task == 'ranking': + _, _, _, _, dX, dy, _, dg = _create_ranking_data( + output='array', + group=None + ) + model_factory = lgb.DaskLGBMRanker + else: + _, _, _, dX, dy, _ = _create_data( + objective=task, + output='array', + ) + dg = None + if task == 'classification': + model_factory = lgb.DaskLGBMClassifier + elif task == 'regression': + model_factory = lgb.DaskLGBMRegressor + + params = { + "time_out": 5, + "local_listen_port": listen_port, + "n_estimators": 1, + "num_leaves": 2 + } + + if set_client: + params.update({"client": client}) + + # unfitted model should survive pickling round trip, and pickling + # shouldn't have side effects on the model object + tmp_file = str(tmp_path / "model-1.pkl") + dask_model = model_factory(**params) + if set_client: + assert dask_model._client == client + else: + assert dask_model._client is None + + assert dask_model.client == client + + _pickle( + obj=dask_model, + filepath=tmp_file, + serializer=serializer + ) + model_from_disk = _unpickle( + filepath=tmp_file, + serializer=serializer + ) + + if set_client: + assert dask_model._client == client + else: + assert dask_model._client is None + assert model_from_disk._client is None + assert model_from_disk.client == client + assert model_from_disk.get_params() == dask_model.get_params() + + # fitted model should survive pickling round trip, and pickling + # shouldn't have side effects on the model object + tmp_file2 = str(tmp_path / 
"model-2.pkl") + dask_model.fit(dX, dy, group=dg) + _pickle( + obj=dask_model, + filepath=tmp_file2, + serializer=serializer + ) + fitted_model_from_disk = _unpickle( + filepath=tmp_file2, + serializer=serializer + ) + + if set_client: + assert dask_model._client == client + else: + assert dask_model._client is None + assert isinstance(fitted_model_from_disk, model_factory) + assert fitted_model_from_disk._client is None + assert fitted_model_from_disk.client == client + assert fitted_model_from_disk.get_params() == dask_model.get_params() + preds_orig = dask_model.predict(dX).compute() + preds_loaded_model = fitted_model_from_disk.predict(dX).compute() + assert_eq(preds_orig, preds_loaded_model) + + def test_find_open_port_works(): worker_ip = '127.0.0.1' with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -442,6 +636,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client): X = da.random.random((1e3, 10)) y = da.random.random((1e3, 1)) dask_regressor = lgb.DaskLGBMRegressor( + client=client, time_out=5, local_listen_port=1234, tree_learner='some-nonsense-value', @@ -449,7 +644,7 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client): num_leaves=2 ) with pytest.warns(UserWarning, match='Parameter tree_learner set to some-nonsense-value'): - dask_regressor = dask_regressor.fit(X, y, client=client) + dask_regressor = dask_regressor.fit(X, y) assert dask_regressor.fitted_ @@ -461,6 +656,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client): y = da.random.random((1e3, 1)) for tree_learner in ['feature_parallel', 'voting']: dask_regressor = lgb.DaskLGBMRegressor( + client=client, time_out=5, local_listen_port=1234, tree_learner=tree_learner, @@ -468,7 +664,7 @@ def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client): num_leaves=2 ) with pytest.warns(UserWarning, match='Support for tree_learner %s in lightgbm' % tree_learner): - dask_regressor = dask_regressor.fit(X, y, client=client) + dask_regressor = dask_regressor.fit(X, y) assert dask_regressor.fitted_ assert dask_regressor.get_params()['tree_learner'] == tree_learner From 2ed56d9356ab86c32ff513c8d35ec046b64104b2 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 01:06:11 -0600 Subject: [PATCH 04/18] just copy docstrings --- python-package/lightgbm/dask.py | 166 ++++++++++++++++++------- tests/python_package_test/test_dask.py | 7 ++ 2 files changed, 128 insertions(+), 45 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 6f0e5e1e13b9..4b1e7614d8ea 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -435,7 +435,7 @@ def _predict( class _DaskLGBMModel: # self._client is set in the constructor of lightgbm.sklearn.LGBMModel - _client = None + _client: Optional[Client] = None @property def client(self) -> Client: @@ -502,6 +502,46 @@ def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union[" class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): """Distributed version of lightgbm.LGBMClassifier.""" + def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, + learning_rate=0.1, n_estimators=100, + subsample_for_bin=200000, objective=None, class_weight=None, + min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, + subsample=1., subsample_freq=0, colsample_bytree=1., + reg_alpha=0., reg_lambda=0., random_state=None, + n_jobs=-1, silent=True, importance_type='split', **kwargs): + super().__init__( + 
boosting_type=boosting_type, + num_leaves=num_leaves, + max_depth=max_depth, + learning_rate=learning_rate, + n_estimators=n_estimators, + subsample_for_bin=subsample_for_bin, + objective=objective, + class_weight=class_weight, + min_split_gain=min_split_gain, + min_child_weight=min_child_weight, + min_child_samples=min_child_samples, + subsample=subsample, + subsample_freq=subsample_freq, + colsample_bytree=colsample_bytree, + reg_alpha=reg_alpha, + reg_lambda=reg_lambda, + random_state=random_state, + n_jobs=n_jobs, + silent=silent, + importance_type=importance_type, + **kwargs + ) + + _base_doc = LGBMClassifier.__init__.__doc__ + _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') + __init__.__doc__ = ( + _before_kwargs + + 'client : dask.distributed.Client or None, optional (default=None)\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 8 + _kwargs + _after_kwargs + ) + def fit( self, X: _DaskMatrixLike, @@ -552,19 +592,49 @@ def to_local(self) -> LGBMClassifier: return self._to_local(LGBMClassifier) -_base_doc = LGBMClassifier.__init__.__doc__ -_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') -DaskLGBMClassifier.__init__.__doc__ = ( - _before_kwargs - + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' - + ' ' * 8 + _kwargs + _after_kwargs -) - - class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRegressor.""" + def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, + learning_rate=0.1, n_estimators=100, + subsample_for_bin=200000, objective=None, class_weight=None, + min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, + subsample=1., subsample_freq=0, colsample_bytree=1., + reg_alpha=0., reg_lambda=0., random_state=None, + n_jobs=-1, silent=True, importance_type='split', **kwargs): + super().__init__( + boosting_type=boosting_type, + num_leaves=num_leaves, + max_depth=max_depth, + learning_rate=learning_rate, + n_estimators=n_estimators, + subsample_for_bin=subsample_for_bin, + objective=objective, + class_weight=class_weight, + min_split_gain=min_split_gain, + min_child_weight=min_child_weight, + min_child_samples=min_child_samples, + subsample=subsample, + subsample_freq=subsample_freq, + colsample_bytree=colsample_bytree, + reg_alpha=reg_alpha, + reg_lambda=reg_lambda, + random_state=random_state, + n_jobs=n_jobs, + silent=silent, + importance_type=importance_type, + **kwargs + ) + + _base_doc = LGBMRegressor.__init__.__doc__ + _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') + __init__.__doc__ = ( + _before_kwargs + + 'client : dask.distributed.Client or None, optional (default=None)\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. 
This client will not be saved if the model object is pickled.\n' + + ' ' * 8 + _kwargs + _after_kwargs + ) + def fit( self, X: _DaskMatrixLike, @@ -583,13 +653,6 @@ def fit( **kwargs ) - _base_doc = LGBMRegressor.fit.__doc__ - _before_init_score, _init_score, _after_init_score = _base_doc.partition('init_score :') - fit.__doc__ = (_before_init_score - + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client.\n' - + ' ' * 8 + _init_score + _after_init_score) - def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" return _predict( @@ -611,19 +674,49 @@ def to_local(self) -> LGBMRegressor: return self._to_local(LGBMRegressor) -_base_doc = LGBMRegressor.__init__.__doc__ -_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') -DaskLGBMRegressor.__init__.__doc__ = ( - _before_kwargs - + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' - + ' ' * 8 + _kwargs + _after_kwargs -) - - class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRanker.""" + def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, + learning_rate=0.1, n_estimators=100, + subsample_for_bin=200000, objective=None, class_weight=None, + min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, + subsample=1., subsample_freq=0, colsample_bytree=1., + reg_alpha=0., reg_lambda=0., random_state=None, + n_jobs=-1, silent=True, importance_type='split', **kwargs): + super().__init__( + boosting_type=boosting_type, + num_leaves=num_leaves, + max_depth=max_depth, + learning_rate=learning_rate, + n_estimators=n_estimators, + subsample_for_bin=subsample_for_bin, + objective=objective, + class_weight=class_weight, + min_split_gain=min_split_gain, + min_child_weight=min_child_weight, + min_child_samples=min_child_samples, + subsample=subsample, + subsample_freq=subsample_freq, + colsample_bytree=colsample_bytree, + reg_alpha=reg_alpha, + reg_lambda=reg_lambda, + random_state=random_state, + n_jobs=n_jobs, + silent=silent, + importance_type=importance_type, + **kwargs + ) + + _base_doc = LGBMRanker.__init__.__doc__ + _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') + __init__.__doc__ = ( + _before_kwargs + + 'client : dask.distributed.Client or None, optional (default=None)\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 8 + _kwargs + _after_kwargs + ) + def fit( self, X: _DaskMatrixLike, @@ -647,13 +740,6 @@ def fit( **kwargs ) - _base_doc = LGBMRanker.fit.__doc__ - _before_eval_set, _eval_set, _after_eval_set = _base_doc.partition('eval_set :') - fit.__doc__ = (_before_eval_set - + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client.\n' - + ' ' * 8 + _eval_set + _after_eval_set) - def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRanker.predict.""" return _predict(self.to_local(), X, **kwargs) @@ -669,13 +755,3 @@ def to_local(self) -> LGBMRanker: Local underlying model. 
""" return self._to_local(LGBMRanker) - - -_base_doc = LGBMRanker.__init__.__doc__ -_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') -DaskLGBMRanker.__init__.__doc__ = ( - _before_kwargs - + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' - + ' ' * 8 + _kwargs + _after_kwargs -) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 98253509add5..bfddbff3f016 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -1,6 +1,7 @@ # coding: utf-8 """Tests for lightgbm.dask module""" +import inspect import joblib import pickle import socket @@ -688,3 +689,9 @@ def f(part): model_factory=lgb.LGBMClassifier ) assert 'foo' in str(info.value) + + +def test_dask_classes_and_sklearn_equivalents_have_identical_constructors(): + assert inspect.getfullargspec(lgb.LGBMClassifier.__init__) == inspect.getfullargspec(lgb.DaskLGBMClassifier.__init__) + assert inspect.getfullargspec(lgb.LGBMRegressor.__init__) == inspect.getfullargspec(lgb.DaskLGBMRegressor.__init__) + assert inspect.getfullargspec(lgb.LGBMRanker.__init__) == inspect.getfullargspec(lgb.DaskLGBMRanker.__init__) From b8e53eda68004090062f55d53004cfacdfd7cb9c Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 01:10:28 -0600 Subject: [PATCH 05/18] fit docs --- python-package/lightgbm/dask.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 4b1e7614d8ea..14d721835e24 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -559,6 +559,8 @@ def fit( **kwargs ) + fit.__doc__ = LGBMClassifier.fit.__doc__ + def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" return _predict( @@ -653,6 +655,8 @@ def fit( **kwargs ) + fit.__doc__ = LGBMRegressor.fit.__doc__ + def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" return _predict( @@ -740,6 +744,8 @@ def fit( **kwargs ) + fit.__doc__ = LGBMRanker.fit.__doc__ + def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRanker.predict.""" return _predict(self.to_local(), X, **kwargs) From f8059392dafd5a8e826b196597f0c88358b15d0d Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 01:57:09 -0600 Subject: [PATCH 06/18] switch test order --- tests/python_package_test/test_dask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index bfddbff3f016..a6f9376fb6bc 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -692,6 +692,6 @@ def f(part): def test_dask_classes_and_sklearn_equivalents_have_identical_constructors(): - assert inspect.getfullargspec(lgb.LGBMClassifier.__init__) == inspect.getfullargspec(lgb.DaskLGBMClassifier.__init__) - assert inspect.getfullargspec(lgb.LGBMRegressor.__init__) == inspect.getfullargspec(lgb.DaskLGBMRegressor.__init__) - assert inspect.getfullargspec(lgb.LGBMRanker.__init__) == inspect.getfullargspec(lgb.DaskLGBMRanker.__init__) + assert 
inspect.getfullargspec(lgb.DaskLGBMClassifier.__init__) == inspect.getfullargspec(lgb.LGBMClassifier.__init__) + assert inspect.getfullargspec(lgb.DaskLGBMRegressor.__init__) == inspect.getfullargspec(lgb.LGBMRegressor.__init__) + assert inspect.getfullargspec(lgb.DaskLGBMRanker.__init__) == inspect.getfullargspec(lgb.LGBMRanker.__init__) From eb2aee08a0810ef50a180a188c76ec1f9ad35e0d Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 02:12:16 -0600 Subject: [PATCH 07/18] linting --- python-package/lightgbm/dask.py | 4 ++-- python-package/lightgbm/sklearn.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 14d721835e24..35e401ea538c 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -439,7 +439,7 @@ class _DaskLGBMModel: @property def client(self) -> Client: - """Dask client + """Dask client. This property can be passed in the constructor or directly assigned like ``model.client = client``. @@ -479,9 +479,9 @@ def _fit( **kwargs ) - # at this point, self._client is still set self.set_params(**model.get_params()) self._copy_extra_params(model, self) + return self def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel: diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 49dc121f6de5..50e2c0d8c085 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -329,7 +329,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, self.set_params(**kwargs) def __getstate__(self): - """Remove un-picklable attributes before serialization""" + """Remove un-picklable attributes before serialization.""" client = self.__dict__.pop("_client", None) out = copy.deepcopy(self.__dict__) self._client = client From 344376bee1a7a3dbe97f7078a782d091d1084c30 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 14:42:20 -0600 Subject: [PATCH 08/18] use client kwarg --- .ci/test.sh | 2 +- python-package/lightgbm/dask.py | 27 +++++++++++++++++++++++---- python-package/lightgbm/sklearn.py | 13 ++----------- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index a150d33ffab0..da93d1b0f97f 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -90,7 +90,7 @@ if [[ $TASK == "swig" ]]; then exit 0 fi -conda install -q -y -n $CONDA_ENV dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy +conda install -q -y -n $CONDA_ENV cloudpickle dask distributed joblib matplotlib numpy pandas psutil pytest scikit-learn scipy # graphviz must come from conda-forge to avoid this on some linux distros: # https://github.com/conda-forge/graphviz-feedstock/issues/18 diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 35e401ea538c..c65e14187567 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -442,7 +442,7 @@ def client(self) -> Client: """Dask client. This property can be passed in the constructor or directly assigned - like ``model.client = client``. + like ``model.set_params(client=client)``. 
""" if self._client is None: return default_client() @@ -453,6 +453,13 @@ def client(self) -> Client: def client(self, client: Client) -> None: self._client = client + def _lgb_getstate(self) -> Dict[Any, Any]: + """Remove un-picklable attributes before serialization.""" + client = self.__dict__.pop("_client", None) + out = copy.deepcopy(self.__dict__) + self.set_params(client=client) + return out + def _fit( self, model_factory: Type[LGBMModel], @@ -508,7 +515,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, subsample=1., subsample_freq=0, colsample_bytree=1., reg_alpha=0., reg_lambda=0., random_state=None, - n_jobs=-1, silent=True, importance_type='split', **kwargs): + n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs): super().__init__( boosting_type=boosting_type, num_leaves=num_leaves, @@ -532,6 +539,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, importance_type=importance_type, **kwargs ) + self.set_params(client=client) _base_doc = LGBMClassifier.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') @@ -542,6 +550,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, + ' ' * 8 + _kwargs + _after_kwargs ) + def __getstate__(self) -> Dict[Any, Any]: + return self._lgb_getstate() + def fit( self, X: _DaskMatrixLike, @@ -603,7 +614,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, subsample=1., subsample_freq=0, colsample_bytree=1., reg_alpha=0., reg_lambda=0., random_state=None, - n_jobs=-1, silent=True, importance_type='split', **kwargs): + n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs): super().__init__( boosting_type=boosting_type, num_leaves=num_leaves, @@ -627,6 +638,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, importance_type=importance_type, **kwargs ) + self.set_params(client=client) _base_doc = LGBMRegressor.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') @@ -637,6 +649,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, + ' ' * 8 + _kwargs + _after_kwargs ) + def __getstate__(self) -> Dict[Any, Any]: + return self._lgb_getstate() + def fit( self, X: _DaskMatrixLike, @@ -687,7 +702,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, subsample=1., subsample_freq=0, colsample_bytree=1., reg_alpha=0., reg_lambda=0., random_state=None, - n_jobs=-1, silent=True, importance_type='split', **kwargs): + n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs): super().__init__( boosting_type=boosting_type, num_leaves=num_leaves, @@ -711,6 +726,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, importance_type=importance_type, **kwargs ) + self.set_params(client=client) _base_doc = LGBMRanker.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') @@ -721,6 +737,9 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, + ' ' * 8 + _kwargs + _after_kwargs ) + def __getstate__(self) -> Dict[Any, Any]: + return self._lgb_getstate() + def fit( self, X: _DaskMatrixLike, diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 50e2c0d8c085..26c28a458512 100644 --- 
a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -291,9 +291,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, if not SKLEARN_INSTALLED: raise LightGBMError('scikit-learn is required for lightgbm.sklearn') - # Dask estimators inherit from this and may pass an argument "client" - self._client = kwargs.pop("client", None) - self.boosting_type = boosting_type self.objective = objective self.num_leaves = num_leaves @@ -328,13 +325,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, self._n_classes = None self.set_params(**kwargs) - def __getstate__(self): - """Remove un-picklable attributes before serialization.""" - client = self.__dict__.pop("_client", None) - out = copy.deepcopy(self.__dict__) - self._client = client - return out - def _more_tags(self): return { 'allow_nan': True, @@ -382,7 +372,8 @@ def set_params(self, **params): setattr(self, key, value) if hasattr(self, '_' + key): setattr(self, '_' + key, value) - self._other_params[key] = value + if key != "client": + self._other_params[key] = value return self def fit(self, X, y, From b0cf6c6b2ff71434dc2a16007090b03aff099e6a Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 21:09:49 -0600 Subject: [PATCH 09/18] remove inner set_params() --- python-package/lightgbm/dask.py | 23 ++++---- python-package/lightgbm/sklearn.py | 3 +- tests/python_package_test/test_dask.py | 78 ++++++++++++++++++++++---- 3 files changed, 79 insertions(+), 25 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index c65e14187567..f0f985ba94b5 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -434,7 +434,7 @@ def _predict( class _DaskLGBMModel: - # self._client is set in the constructor of lightgbm.sklearn.LGBMModel + # self._client is set in the constructor of classes that use this mixin _client: Optional[Client] = None @property @@ -455,9 +455,10 @@ def client(self, client: Client) -> None: def _lgb_getstate(self) -> Dict[Any, Any]: """Remove un-picklable attributes before serialization.""" + self._other_params.pop("client", None) client = self.__dict__.pop("_client", None) - out = copy.deepcopy(self.__dict__) - self.set_params(client=client) + out = deepcopy(self.__dict__) + self.client = client return out def _fit( @@ -467,13 +468,13 @@ def _fit( y: _DaskCollection, sample_weight: Optional[_DaskCollection] = None, group: Optional[_DaskCollection] = None, - client: Optional[Client] = None, **kwargs: Any ) -> "_DaskLGBMModel": if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') params = self.get_params(True) + params.pop("client", None) model = _train( client=self.client, @@ -492,8 +493,11 @@ def _fit( return self def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel: - model = model_factory(**self.get_params()) + params = self.get_params() + params.pop("client", None) + model = model_factory(**params) self._copy_extra_params(self, model) + model._other_params.pop("client", None) return model @staticmethod @@ -516,6 +520,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, subsample=1., subsample_freq=0, colsample_bytree=1., reg_alpha=0., reg_lambda=0., random_state=None, n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs): + self.client = client super().__init__( boosting_type=boosting_type, num_leaves=num_leaves, @@ -539,7 +544,6 @@ def 
__init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, importance_type=importance_type, **kwargs ) - self.set_params(client=client) _base_doc = LGBMClassifier.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') @@ -566,7 +570,6 @@ def fit( X=X, y=y, sample_weight=sample_weight, - client=self.client, **kwargs ) @@ -615,6 +618,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, subsample=1., subsample_freq=0, colsample_bytree=1., reg_alpha=0., reg_lambda=0., random_state=None, n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs): + self.client = client super().__init__( boosting_type=boosting_type, num_leaves=num_leaves, @@ -638,7 +642,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, importance_type=importance_type, **kwargs ) - self.set_params(client=client) _base_doc = LGBMRegressor.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') @@ -666,7 +669,6 @@ def fit( X=X, y=y, sample_weight=sample_weight, - client=client, **kwargs ) @@ -703,6 +705,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, subsample=1., subsample_freq=0, colsample_bytree=1., reg_alpha=0., reg_lambda=0., random_state=None, n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs): + self.client = client super().__init__( boosting_type=boosting_type, num_leaves=num_leaves, @@ -726,7 +729,6 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, importance_type=importance_type, **kwargs ) - self.set_params(client=client) _base_doc = LGBMRanker.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') @@ -759,7 +761,6 @@ def fit( y=y, sample_weight=sample_weight, group=group, - client=self.client, **kwargs ) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 26c28a458512..6a48f2c7b7d5 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -372,8 +372,7 @@ def set_params(self, **params): setattr(self, key, value) if hasattr(self, '_' + key): setattr(self, '_' + key, value) - if key != "client": - self._other_params[key] = value + self._other_params[key] = value return self def fit(self, X, y, diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index a6f9376fb6bc..bf3bdb3ea5f6 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -494,13 +494,13 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l assert dask_model.client == client local_model = dask_model.to_local() - assert local_model._client is None + assert getattr(local_model, "_client", None) is None with pytest.raises(AttributeError): local_model.client # should be able to set client after construction dask_model = model_factory(**params) - dask_model.client = client + dask_model.set_params(client=client) assert dask_model._client == client assert dask_model.client == client @@ -516,7 +516,7 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l assert dask_model.client == client local_model = dask_model.to_local() - assert local_model._client is None + assert getattr(local_model, "_client", None) is None with pytest.raises(AttributeError): local_model.client @@ -526,15 +526,18 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l @pytest.mark.parametrize('serializer', ['pickle', 
'joblib', 'cloudpickle']) @pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) @pytest.mark.parametrize('set_client', [True, False]) -def test_model_is_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, client, tmp_path): +# @pytest.mark.parametrize('serializer', ['pickle']) +# @pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) +# @pytest.mark.parametrize('set_client', [True]) +def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, client, tmp_path): if task == 'ranking': - _, _, _, _, dX, dy, _, dg = _create_ranking_data( + X, _, _, _, dX, dy, _, dg = _create_ranking_data( output='array', group=None ) model_factory = lgb.DaskLGBMRanker else: - _, _, _, dX, dy, _ = _create_data( + X, _, _, dX, dy, _ = _create_data( objective=task, output='array', ) @@ -556,15 +559,18 @@ def test_model_is_picklable_whether_or_not_client_set_explicitly(serializer, tas # unfitted model should survive pickling round trip, and pickling # shouldn't have side effects on the model object - tmp_file = str(tmp_path / "model-1.pkl") dask_model = model_factory(**params) + local_model = dask_model.to_local() if set_client: assert dask_model._client == client else: assert dask_model._client is None assert dask_model.client == client + assert "client" not in local_model.get_params() + assert getattr(local_model, "client", None) is None + tmp_file = str(tmp_path / "model-1.pkl") _pickle( obj=dask_model, filepath=tmp_file, @@ -575,6 +581,17 @@ def test_model_is_picklable_whether_or_not_client_set_explicitly(serializer, tas serializer=serializer ) + local_tmp_file = str(tmp_path / "local-model-1.pkl") + _pickle( + obj=local_model, + filepath=local_tmp_file, + serializer=serializer + ) + local_model_from_disk = _unpickle( + filepath=local_tmp_file, + serializer=serializer + ) + if set_client: assert dask_model._client == client else: @@ -582,11 +599,17 @@ def test_model_is_picklable_whether_or_not_client_set_explicitly(serializer, tas assert model_from_disk._client is None assert model_from_disk.client == client assert model_from_disk.get_params() == dask_model.get_params() + assert local_model_from_disk.get_params() == local_model.get_params() # fitted model should survive pickling round trip, and pickling # shouldn't have side effects on the model object - tmp_file2 = str(tmp_path / "model-2.pkl") dask_model.fit(dX, dy, group=dg) + local_model = dask_model.to_local() + + assert "client" not in local_model.get_params() + assert getattr(local_model, "client", None) is None + + tmp_file2 = str(tmp_path / "model-2.pkl") _pickle( obj=dask_model, filepath=tmp_file2, @@ -597,6 +620,17 @@ def test_model_is_picklable_whether_or_not_client_set_explicitly(serializer, tas serializer=serializer ) + local_tmp_file2 = str(tmp_path / "local-model-2.pkl") + _pickle( + obj=local_model, + filepath=local_tmp_file2, + serializer=serializer + ) + local_fitted_model_from_disk = _unpickle( + filepath=local_tmp_file2, + serializer=serializer + ) + if set_client: assert dask_model._client == client else: @@ -605,10 +639,15 @@ def test_model_is_picklable_whether_or_not_client_set_explicitly(serializer, tas assert fitted_model_from_disk._client is None assert fitted_model_from_disk.client == client assert fitted_model_from_disk.get_params() == dask_model.get_params() + preds_orig = dask_model.predict(dX).compute() preds_loaded_model = fitted_model_from_disk.predict(dX).compute() 
assert_eq(preds_orig, preds_loaded_model) + preds_orig_local = local_model.predict(X) + preds_loaded_model_local = local_fitted_model_from_disk.predict(X) + assert_eq(preds_orig_local, preds_loaded_model_local) + def test_find_open_port_works(): worker_ip = '127.0.0.1' @@ -691,7 +730,22 @@ def f(part): assert 'foo' in str(info.value) -def test_dask_classes_and_sklearn_equivalents_have_identical_constructors(): - assert inspect.getfullargspec(lgb.DaskLGBMClassifier.__init__) == inspect.getfullargspec(lgb.LGBMClassifier.__init__) - assert inspect.getfullargspec(lgb.DaskLGBMRegressor.__init__) == inspect.getfullargspec(lgb.LGBMRegressor.__init__) - assert inspect.getfullargspec(lgb.DaskLGBMRanker.__init__) == inspect.getfullargspec(lgb.LGBMRanker.__init__) +def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(): + def _compare_spec(dask_cls, sklearn_cls): + dask_spec = inspect.getfullargspec(dask_cls) + sklearn_spec = inspect.getfullargspec(sklearn_cls) + assert dask_spec.varargs == sklearn_spec.varargs + assert dask_spec.varkw == sklearn_spec.varkw + assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs + assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults + assert dask_spec.annotations == sklearn_spec.annotations + + # "client" should be the only different, and the final argument + assert dask_spec.args[:-1] == sklearn_spec.args + assert dask_spec.defaults[:-1] == sklearn_spec.defaults + assert dask_spec.args[-1] == 'client' + assert dask_spec.defaults[-1] is None + + _compare_spec(lgb.DaskLGBMClassifier, lgb.LGBMClassifier) + _compare_spec(lgb.DaskLGBMRegressor, lgb.LGBMRegressor) + _compare_spec(lgb.DaskLGBMRanker, lgb.LGBMRanker) From 56eb582754c10601a0256921c37ca3a92ffa25fd Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 21:25:14 -0600 Subject: [PATCH 10/18] add type hints --- python-package/lightgbm/dask.py | 99 +++++++++++++++++++++++++-------- 1 file changed, 77 insertions(+), 22 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index f0f985ba94b5..be57347dee26 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -9,11 +9,12 @@ import socket from collections import defaultdict from copy import deepcopy -from typing import Any, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union from urllib.parse import urlparse import numpy as np import scipy.sparse as ss +from numpy.random import RandomState from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat, @@ -513,13 +514,31 @@ def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union[" class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): """Distributed version of lightgbm.LGBMClassifier.""" - def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, - learning_rate=0.1, n_estimators=100, - subsample_for_bin=200000, objective=None, class_weight=None, - min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, - subsample=1., subsample_freq=0, colsample_bytree=1., - reg_alpha=0., reg_lambda=0., random_state=None, - n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs): + def __init__( + self, + boosting_type: str = 'gbdt', + num_leaves: int = 31, + max_depth: int = -1, + learning_rate: float = 0.1, + n_estimators: int = 100, + subsample_for_bin: int = 
200000, + objective: Optional[Union[Callable, str]] = None, + class_weight: Optional[Union[dict, str]] = None, + min_split_gain: float = 0., + min_child_weight: float = 1e-3, + min_child_samples: int = 20, + subsample: float = 1., + subsample_freq: int = 0, + colsample_bytree: float = 1., + reg_alpha: float = 0., + reg_lambda: float = 0., + random_state: Optional[Union[int, RandomState]] = None, + n_jobs: int = -1, + silent: bool = True, + importance_type: str = 'split', + client: Optional[Client] = None, + **kwargs: Any + ): self.client = client super().__init__( boosting_type=boosting_type, @@ -611,13 +630,31 @@ def to_local(self) -> LGBMClassifier: class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRegressor.""" - def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, - learning_rate=0.1, n_estimators=100, - subsample_for_bin=200000, objective=None, class_weight=None, - min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, - subsample=1., subsample_freq=0, colsample_bytree=1., - reg_alpha=0., reg_lambda=0., random_state=None, - n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs): + def __init__( + self, + boosting_type: str = 'gbdt', + num_leaves: int = 31, + max_depth: int = -1, + learning_rate: float = 0.1, + n_estimators: int = 100, + subsample_for_bin: int = 200000, + objective: Optional[Union[Callable, str]] = None, + class_weight: Optional[Union[dict, str]] = None, + min_split_gain: float = 0., + min_child_weight: float = 1e-3, + min_child_samples: int = 20, + subsample: float = 1., + subsample_freq: int = 0, + colsample_bytree: float = 1., + reg_alpha: float = 0., + reg_lambda: float = 0., + random_state: Optional[Union[int, RandomState]] = None, + n_jobs: int = -1, + silent: bool = True, + importance_type: str = 'split', + client: Optional[Client] = None, + **kwargs: Any + ): self.client = client super().__init__( boosting_type=boosting_type, @@ -698,13 +735,31 @@ def to_local(self) -> LGBMRegressor: class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRanker.""" - def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1, - learning_rate=0.1, n_estimators=100, - subsample_for_bin=200000, objective=None, class_weight=None, - min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, - subsample=1., subsample_freq=0, colsample_bytree=1., - reg_alpha=0., reg_lambda=0., random_state=None, - n_jobs=-1, silent=True, importance_type='split', client=None, **kwargs): + def __init__( + self, + boosting_type: str = 'gbdt', + num_leaves: int = 31, + max_depth: int = -1, + learning_rate: float = 0.1, + n_estimators: int = 100, + subsample_for_bin: int = 200000, + objective: Optional[Union[Callable, str]] = None, + class_weight: Optional[Union[dict, str]] = None, + min_split_gain: float = 0., + min_child_weight: float = 1e-3, + min_child_samples: int = 20, + subsample: float = 1., + subsample_freq: int = 0, + colsample_bytree: float = 1., + reg_alpha: float = 0., + reg_lambda: float = 0., + random_state: Optional[Union[int, RandomState]] = None, + n_jobs: int = -1, + silent: bool = True, + importance_type: str = 'split', + client: Optional[Client] = None, + **kwargs: Any + ): self.client = client super().__init__( boosting_type=boosting_type, From 4a4213396d2a68084469505ac3cdb124847b59d2 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 22:20:42 -0600 Subject: [PATCH 11/18] fix type hints --- python-package/lightgbm/dask.py | 11 
+++++++---- tests/python_package_test/test_dask.py | 1 - 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index be57347dee26..433d5d4904c0 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -14,7 +14,7 @@ import numpy as np import scipy.sparse as ss -from numpy.random import RandomState +#from np.random import RandomState from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat, @@ -532,13 +532,14 @@ def __init__( colsample_bytree: float = 1., reg_alpha: float = 0., reg_lambda: float = 0., - random_state: Optional[Union[int, RandomState]] = None, + random_state: Optional[Union[int, np.random.RandomState]] = None, n_jobs: int = -1, silent: bool = True, importance_type: str = 'split', client: Optional[Client] = None, **kwargs: Any ): + """Docstring is inherited from the lightgbm.LGBMClassifier.__init__.""" self.client = client super().__init__( boosting_type=boosting_type, @@ -648,13 +649,14 @@ def __init__( colsample_bytree: float = 1., reg_alpha: float = 0., reg_lambda: float = 0., - random_state: Optional[Union[int, RandomState]] = None, + random_state: Optional[Union[int, np.random.RandomState]] = None, n_jobs: int = -1, silent: bool = True, importance_type: str = 'split', client: Optional[Client] = None, **kwargs: Any ): + """Docstring is inherited from the lightgbm.LGBMRegressor.__init__.""" self.client = client super().__init__( boosting_type=boosting_type, @@ -753,13 +755,14 @@ def __init__( colsample_bytree: float = 1., reg_alpha: float = 0., reg_lambda: float = 0., - random_state: Optional[Union[int, RandomState]] = None, + random_state: Optional[Union[int, np.random.RandomState]] = None, n_jobs: int = -1, silent: bool = True, importance_type: str = 'split', client: Optional[Client] = None, **kwargs: Any ): + """Docstring is inherited from the lightgbm.LGBMRanker.__init__.""" self.client = client super().__init__( boosting_type=boosting_type, diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index bf3bdb3ea5f6..591cd34808eb 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -738,7 +738,6 @@ def _compare_spec(dask_cls, sklearn_cls): assert dask_spec.varkw == sklearn_spec.varkw assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults - assert dask_spec.annotations == sklearn_spec.annotations # "client" should be the only different, and the final argument assert dask_spec.args[:-1] == sklearn_spec.args From 87f76aa1d043682a735e417545321dce2e1ba32b Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 22:23:51 -0600 Subject: [PATCH 12/18] remove commented code --- python-package/lightgbm/dask.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 433d5d4904c0..e094491b99b0 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -14,7 +14,6 @@ import numpy as np import scipy.sparse as ss -#from np.random import RandomState from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat, From 80dc6b9cab1ec5d833e8e86792128f8c287953ae Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sun, 31 Jan 2021 
23:02:21 -0600 Subject: [PATCH 13/18] reorder --- python-package/lightgbm/dask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index e094491b99b0..17719b5666f3 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -455,9 +455,9 @@ def client(self, client: Client) -> None: def _lgb_getstate(self) -> Dict[Any, Any]: """Remove un-picklable attributes before serialization.""" - self._other_params.pop("client", None) - client = self.__dict__.pop("_client", None) out = deepcopy(self.__dict__) + client = out.pop("_client", None) + self._other_params.pop("client", None) self.client = client return out From af269b56f88b4d2d0f9179dee471c869aa9713b7 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 1 Feb 2021 10:22:53 -0600 Subject: [PATCH 14/18] fix tests, add client_ property --- python-package/lightgbm/dask.py | 22 +++++----- tests/python_package_test/test_dask.py | 61 ++++++++++++++++++++------ 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 17719b5666f3..9833c724356b 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -438,26 +438,25 @@ class _DaskLGBMModel: _client: Optional[Client] = None @property - def client(self) -> Client: + def client_(self) -> Client: """Dask client. - This property can be passed in the constructor or directly assigned - like ``model.set_params(client=client)``. + This property can be passed in the constructor or updated + with ``model.set_params(client=client)``. """ if self._client is None: return default_client() else: return self._client - @client.setter - def client(self, client: Client) -> None: - self._client = client - def _lgb_getstate(self) -> Dict[Any, Any]: """Remove un-picklable attributes before serialization.""" - out = deepcopy(self.__dict__) - client = out.pop("_client", None) + client = self.__dict__.pop("client", None) + self.__dict__.pop("_client", None) self._other_params.pop("client", None) + out = deepcopy(self.__dict__) + out.update({"_client": None, "client": None}) + self._client = client self.client = client return out @@ -477,7 +476,7 @@ def _fit( params.pop("client", None) model = _train( - client=self.client, + client=self.client_, data=X, label=y, params=params, @@ -539,6 +538,7 @@ def __init__( **kwargs: Any ): """Docstring is inherited from the lightgbm.LGBMClassifier.__init__.""" + self._client = client self.client = client super().__init__( boosting_type=boosting_type, @@ -656,6 +656,7 @@ def __init__( **kwargs: Any ): """Docstring is inherited from the lightgbm.LGBMRegressor.__init__.""" + self._client = client self.client = client super().__init__( boosting_type=boosting_type, @@ -762,6 +763,7 @@ def __init__( **kwargs: Any ): """Docstring is inherited from the lightgbm.LGBMRanker.__init__.""" + self._client = client self.client = client super().__init__( boosting_type=boosting_type, diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 591cd34808eb..1cd5489140e1 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -480,45 +480,54 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l # fit should work if client isn't provided dask_model = model_factory(**params) assert dask_model._client is None - assert dask_model.client == client + assert dask_model.client is None + 
assert dask_model.client_ == client dask_model.fit(dX, dy, group=dg) assert dask_model.fitted_ assert dask_model._client is None - assert dask_model.client == client + assert dask_model.client is None + assert dask_model.client_ == client preds = dask_model.predict(dX) assert isinstance(preds, da.Array) assert dask_model.fitted_ assert dask_model._client is None - assert dask_model.client == client + assert dask_model.client is None + assert dask_model.client_ == client local_model = dask_model.to_local() - assert getattr(local_model, "_client", None) is None with pytest.raises(AttributeError): + local_model._client local_model.client + local_model.client_ # should be able to set client after construction dask_model = model_factory(**params) dask_model.set_params(client=client) assert dask_model._client == client assert dask_model.client == client + assert dask_model.client_ == client dask_model.fit(dX, dy, group=dg) assert dask_model.fitted_ assert dask_model._client == client assert dask_model.client == client + assert dask_model.client_ == client preds = dask_model.predict(dX) assert isinstance(preds, da.Array) assert dask_model.fitted_ assert dask_model._client == client assert dask_model.client == client + assert dask_model.client_ == client local_model = dask_model.to_local() assert getattr(local_model, "_client", None) is None with pytest.raises(AttributeError): + local_model._client local_model.client + local_model.client_ client.close(timeout=CLIENT_CLOSE_TIMEOUT) @@ -526,9 +535,6 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l @pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle']) @pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) @pytest.mark.parametrize('set_client', [True, False]) -# @pytest.mark.parametrize('serializer', ['pickle']) -# @pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) -# @pytest.mark.parametrize('set_client', [True]) def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, client, tmp_path): if task == 'ranking': X, _, _, _, dX, dy, _, dg = _create_ranking_data( @@ -563,10 +569,12 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici local_model = dask_model.to_local() if set_client: assert dask_model._client == client + assert dask_model.client == client else: assert dask_model._client is None + assert dask_model.client is None - assert dask_model.client == client + assert dask_model.client_ == client assert "client" not in local_model.get_params() assert getattr(local_model, "client", None) is None @@ -594,11 +602,22 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici if set_client: assert dask_model._client == client + assert dask_model.client == client else: assert dask_model._client is None + assert dask_model.client is None assert model_from_disk._client is None - assert model_from_disk.client == client - assert model_from_disk.get_params() == dask_model.get_params() + assert model_from_disk.client is None + assert model_from_disk.client_ == client + # client will always be None after unpickling + if set_client: + from_disk_params = model_from_disk.get_params() + from_disk_params.pop("client", None) + dask_params = dask_model.get_params() + dask_params.pop("client", None) + assert from_disk_params == dask_params + else: + assert model_from_disk.get_params() == dask_model.get_params() assert 
local_model_from_disk.get_params() == local_model.get_params() # fitted model should survive pickling round trip, and pickling @@ -607,7 +626,10 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici local_model = dask_model.to_local() assert "client" not in local_model.get_params() - assert getattr(local_model, "client", None) is None + with pytest.raises(AttributeError): + local_model._client + local_model.client + local_model.client_ tmp_file2 = str(tmp_path / "model-2.pkl") _pickle( @@ -633,12 +655,25 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici if set_client: assert dask_model._client == client + assert dask_model.client == client else: assert dask_model._client is None + assert dask_model.client is None assert isinstance(fitted_model_from_disk, model_factory) assert fitted_model_from_disk._client is None - assert fitted_model_from_disk.client == client - assert fitted_model_from_disk.get_params() == dask_model.get_params() + assert fitted_model_from_disk.client is None + assert fitted_model_from_disk.client_ == client + + # client will always be None after unpickling + if set_client: + from_disk_params = fitted_model_from_disk.get_params() + from_disk_params.pop("client", None) + dask_params = dask_model.get_params() + dask_params.pop("client", None) + assert from_disk_params == dask_params + else: + assert fitted_model_from_disk.get_params() == dask_model.get_params() + assert local_fitted_model_from_disk.get_params() == local_model.get_params() preds_orig = dask_model.predict(dX).compute() preds_loaded_model = fitted_model_from_disk.predict(dX).compute() From 8cd61014fc695c2f38b576c0722cda2a636adb1e Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 2 Feb 2021 13:10:23 -0600 Subject: [PATCH 15/18] Apply suggestions from code review Co-authored-by: Nikita Titov --- tests/python_package_test/test_dask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 16b929700a95..955c6e8ebf94 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -151,7 +151,7 @@ def _pickle(obj, filepath, serializer): with open(filepath, 'wb') as f: cloudpickle.dump(obj, f) else: - raise ValueError(f'unrecognized serializer type: {serializer}') + raise ValueError(f'Unrecognized serializer type: {serializer}') def _unpickle(filepath, serializer): @@ -164,7 +164,7 @@ def _unpickle(filepath, serializer): with open(filepath, 'rb') as f: return cloudpickle.load(f) else: - raise ValueError(f'unrecognized serializer type: {serializer}') + raise ValueError(f'Unrecognized serializer type: {serializer}') @pytest.mark.parametrize('output', data_output) From 555a57a47345e6b5b1a205b88a4c794bbf096e50 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 2 Feb 2021 16:27:57 -0600 Subject: [PATCH 16/18] fix tests --- python-package/lightgbm/dask.py | 37 ++- tests/python_package_test/test_dask.py | 365 ++++++++++++++----------- 2 files changed, 238 insertions(+), 164 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 9833c724356b..0ba89df2cee0 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -17,7 +17,7 @@ from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat, - SKLEARN_INSTALLED, + SKLEARN_INSTALLED, 
LGBMNotFittedError, DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait) from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker @@ -27,6 +27,25 @@ _PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]] +def _get_dask_client(client: Optional[Client]) -> Client: + """Choose a Dask client to use + + Parameters + ---------- + client : dask.distributed.Client or None + Dask client. + + Returns + ------- + client : dask.distributed.Client + A Dask client. + """ + if client is None: + return default_client() + else: + return client + + def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int: """Find an open port. @@ -444,10 +463,10 @@ def client_(self) -> Client: This property can be passed in the constructor or updated with ``model.set_params(client=client)``. """ - if self._client is None: - return default_client() - else: - return self._client + if not getattr(self, "fitted_", False): + raise LGBMNotFittedError('Cannot access property client_ before calling fit().') + + return _get_dask_client(client=self.client) def _lgb_getstate(self) -> Dict[Any, Any]: """Remove un-picklable attributes before serialization.""" @@ -476,7 +495,7 @@ def _fit( params.pop("client", None) model = _train( - client=self.client_, + client=_get_dask_client(self.client), data=X, label=y, params=params, @@ -569,7 +588,7 @@ def __init__( __init__.__doc__ = ( _before_kwargs + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n' + ' ' * 8 + _kwargs + _after_kwargs ) @@ -687,7 +706,7 @@ def __init__( __init__.__doc__ = ( _before_kwargs + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.\n' + ' ' * 8 + _kwargs + _after_kwargs ) @@ -794,7 +813,7 @@ def __init__( __init__.__doc__ = ( _before_kwargs + 'client : dask.distributed.Client or None, optional (default=None)\n' - + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. This client will not be saved if the model object is pickled.\n' + + ' ' * 12 + 'Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. 
The Dask client used by this class will not be saved if the model object is pickled.\n' + ' ' * 8 + _kwargs + _after_kwargs ) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 955c6e8ebf94..e004c9cf933d 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -23,7 +23,7 @@ import pandas as pd from scipy.stats import spearmanr from dask.array.utils import assert_eq -from dask.distributed import wait +from dask.distributed import default_client, Client, LocalCluster, wait from distributed.utils_test import client, cluster_fixture, gen_cluster, loop from scipy.sparse import csr_matrix from sklearn.datasets import make_blobs, make_regression @@ -486,11 +486,12 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l "num_leaves": 2 } - # fit should work if client isn't provided + # should be able to use the class without specifying a client dask_model = model_factory(**params) assert dask_model._client is None assert dask_model.client is None - assert dask_model.client_ == client + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ dask_model.fit(dX, dy, group=dg) assert dask_model.fitted_ @@ -516,7 +517,9 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l dask_model.set_params(client=client) assert dask_model._client == client assert dask_model.client == client - assert dask_model.client_ == client + + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ dask_model.fit(dX, dy, group=dg) assert dask_model.fitted_ @@ -544,153 +547,199 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l @pytest.mark.parametrize('serializer', ['pickle', 'joblib', 'cloudpickle']) @pytest.mark.parametrize('task', ['classification', 'regression', 'ranking']) @pytest.mark.parametrize('set_client', [True, False]) -def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, client, tmp_path): - if task == 'ranking': - X, _, _, _, dX, dy, _, dg = _create_ranking_data( - output='array', - group=None - ) - model_factory = lgb.DaskLGBMRanker - else: - X, _, _, dX, dy, _ = _create_data( - objective=task, - output='array', - ) - dg = None - if task == 'classification': - model_factory = lgb.DaskLGBMClassifier - elif task == 'regression': - model_factory = lgb.DaskLGBMRegressor - - params = { - "time_out": 5, - "local_listen_port": listen_port, - "n_estimators": 1, - "num_leaves": 2 - } - - if set_client: - params.update({"client": client}) - - # unfitted model should survive pickling round trip, and pickling - # shouldn't have side effects on the model object - dask_model = model_factory(**params) - local_model = dask_model.to_local() - if set_client: - assert dask_model._client == client - assert dask_model.client == client - else: - assert dask_model._client is None - assert dask_model.client is None - - assert dask_model.client_ == client - assert "client" not in local_model.get_params() - assert getattr(local_model, "client", None) is None - - tmp_file = str(tmp_path / "model-1.pkl") - _pickle( - obj=dask_model, - filepath=tmp_file, - serializer=serializer - ) - model_from_disk = _unpickle( - filepath=tmp_file, - serializer=serializer - ) - - local_tmp_file = str(tmp_path / "local-model-1.pkl") - _pickle( - 
obj=local_model, - filepath=local_tmp_file, - serializer=serializer - ) - local_model_from_disk = _unpickle( - filepath=local_tmp_file, - serializer=serializer - ) - - if set_client: - assert dask_model._client == client - assert dask_model.client == client - else: - assert dask_model._client is None - assert dask_model.client is None - assert model_from_disk._client is None - assert model_from_disk.client is None - assert model_from_disk.client_ == client - # client will always be None after unpickling - if set_client: - from_disk_params = model_from_disk.get_params() - from_disk_params.pop("client", None) - dask_params = dask_model.get_params() - dask_params.pop("client", None) - assert from_disk_params == dask_params - else: - assert model_from_disk.get_params() == dask_model.get_params() - assert local_model_from_disk.get_params() == local_model.get_params() - - # fitted model should survive pickling round trip, and pickling - # shouldn't have side effects on the model object - dask_model.fit(dX, dy, group=dg) - local_model = dask_model.to_local() - - assert "client" not in local_model.get_params() - with pytest.raises(AttributeError): - local_model._client - local_model.client - local_model.client_ - - tmp_file2 = str(tmp_path / "model-2.pkl") - _pickle( - obj=dask_model, - filepath=tmp_file2, - serializer=serializer - ) - fitted_model_from_disk = _unpickle( - filepath=tmp_file2, - serializer=serializer - ) - - local_tmp_file2 = str(tmp_path / "local-model-2.pkl") - _pickle( - obj=local_model, - filepath=local_tmp_file2, - serializer=serializer - ) - local_fitted_model_from_disk = _unpickle( - filepath=local_tmp_file2, - serializer=serializer - ) - - if set_client: - assert dask_model._client == client - assert dask_model.client == client - else: - assert dask_model._client is None - assert dask_model.client is None - assert isinstance(fitted_model_from_disk, model_factory) - assert fitted_model_from_disk._client is None - assert fitted_model_from_disk.client is None - assert fitted_model_from_disk.client_ == client - - # client will always be None after unpickling - if set_client: - from_disk_params = fitted_model_from_disk.get_params() - from_disk_params.pop("client", None) - dask_params = dask_model.get_params() - dask_params.pop("client", None) - assert from_disk_params == dask_params - else: - assert fitted_model_from_disk.get_params() == dask_model.get_params() - assert local_fitted_model_from_disk.get_params() == local_model.get_params() - - preds_orig = dask_model.predict(dX).compute() - preds_loaded_model = fitted_model_from_disk.predict(dX).compute() - assert_eq(preds_orig, preds_loaded_model) - - preds_orig_local = local_model.predict(X) - preds_loaded_model_local = local_fitted_model_from_disk.predict(X) - assert_eq(preds_orig_local, preds_loaded_model_local) +def test_model_and_local_version_are_picklable_whether_or_not_client_set_explicitly(serializer, task, set_client, listen_port, tmp_path): + + with LocalCluster(n_workers=2, threads_per_worker=1) as cluster1: + with Client(cluster1) as client1: + + # data on cluster1 + if task == 'ranking': + X_1, _, _, _, dX_1, dy_1, _, dg_1 = _create_ranking_data( + output='array', + group=None + ) + else: + X_1, _, _, dX_1, dy_1, _ = _create_data( + objective=task, + output='array', + ) + dg_1 = None + + with LocalCluster(n_workers=2, threads_per_worker=1) as cluster2: + with Client(cluster2) as client2: + + # create identical data on cluster2 + if task == 'ranking': + X_2, _, _, _, dX_2, dy_2, _, dg_2 = _create_ranking_data( + 
output='array', + group=None + ) + else: + X_2, _, _, dX_2, dy_2, _ = _create_data( + objective=task, + output='array', + ) + dg_2 = None + + if task == 'ranking': + model_factory = lgb.DaskLGBMRanker + elif task == 'classification': + model_factory = lgb.DaskLGBMClassifier + elif task == 'regression': + model_factory = lgb.DaskLGBMRegressor + + params = { + "time_out": 5, + "local_listen_port": listen_port, + "n_estimators": 1, + "num_leaves": 2 + } + + # at this point, the result of default_client() is client2 since it was the most recently + # created. So setting client to client1 here to test that you can select a non-default client + assert default_client() == client2 + if set_client: + params.update({"client": client1}) + + # unfitted model should survive pickling round trip, and pickling + # shouldn't have side effects on the model object + dask_model = model_factory(**params) + local_model = dask_model.to_local() + if set_client: + assert dask_model._client == client1 + assert dask_model.client == client1 + else: + assert dask_model._client is None + assert dask_model.client is None + + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ + + assert "client" not in local_model.get_params() + assert getattr(local_model, "client", None) is None + + tmp_file = str(tmp_path / "model-1.pkl") + _pickle( + obj=dask_model, + filepath=tmp_file, + serializer=serializer + ) + model_from_disk = _unpickle( + filepath=tmp_file, + serializer=serializer + ) + + local_tmp_file = str(tmp_path / "local-model-1.pkl") + _pickle( + obj=local_model, + filepath=local_tmp_file, + serializer=serializer + ) + local_model_from_disk = _unpickle( + filepath=local_tmp_file, + serializer=serializer + ) + + assert model_from_disk._client is None + assert model_from_disk.client is None + + if set_client: + assert dask_model._client == client1 + assert dask_model.client == client1 + else: + assert dask_model._client is None + assert dask_model.client is None + + with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): + dask_model.client_ + + # client will always be None after unpickling + if set_client: + from_disk_params = model_from_disk.get_params() + from_disk_params.pop("client", None) + dask_params = dask_model.get_params() + dask_params.pop("client", None) + assert from_disk_params == dask_params + else: + assert model_from_disk.get_params() == dask_model.get_params() + assert local_model_from_disk.get_params() == local_model.get_params() + + # fitted model should survive pickling round trip, and pickling + # shouldn't have side effects on the model object + if set_client: + dask_model.fit(dX_1, dy_1, group=dg_1) + else: + dask_model.fit(dX_2, dy_2, group=dg_2) + local_model = dask_model.to_local() + + assert "client" not in local_model.get_params() + with pytest.raises(AttributeError): + local_model._client + local_model.client + local_model.client_ + + tmp_file2 = str(tmp_path / "model-2.pkl") + _pickle( + obj=dask_model, + filepath=tmp_file2, + serializer=serializer + ) + fitted_model_from_disk = _unpickle( + filepath=tmp_file2, + serializer=serializer + ) + + local_tmp_file2 = str(tmp_path / "local-model-2.pkl") + _pickle( + obj=local_model, + filepath=local_tmp_file2, + serializer=serializer + ) + local_fitted_model_from_disk = _unpickle( + filepath=local_tmp_file2, + serializer=serializer + ) + + if set_client: + assert dask_model._client == client1 + assert 
dask_model.client == client1 + assert dask_model.client_ == client1 + else: + assert dask_model._client is None + assert dask_model.client is None + assert dask_model.client_ == default_client() + assert dask_model.client_ == client2 + + assert isinstance(fitted_model_from_disk, model_factory) + assert fitted_model_from_disk._client is None + assert fitted_model_from_disk.client is None + assert fitted_model_from_disk.client_ == default_client() + assert fitted_model_from_disk.client_ == client2 + + # client will always be None after unpickling + if set_client: + from_disk_params = fitted_model_from_disk.get_params() + from_disk_params.pop("client", None) + dask_params = dask_model.get_params() + dask_params.pop("client", None) + assert from_disk_params == dask_params + else: + assert fitted_model_from_disk.get_params() == dask_model.get_params() + assert local_fitted_model_from_disk.get_params() == local_model.get_params() + + if set_client: + preds_orig = dask_model.predict(dX_1).compute() + preds_loaded_model = fitted_model_from_disk.predict(dX_1).compute() + preds_orig_local = local_model.predict(X_1) + preds_loaded_model_local = local_fitted_model_from_disk.predict(X_1) + else: + preds_orig = dask_model.predict(dX_2).compute() + preds_loaded_model = fitted_model_from_disk.predict(dX_2).compute() + preds_orig_local = local_model.predict(X_2) + preds_loaded_model_local = local_fitted_model_from_disk.predict(X_2) + + assert_eq(preds_orig, preds_loaded_model) + assert_eq(preds_orig_local, preds_loaded_model_local) def test_find_open_port_works(): @@ -774,7 +823,15 @@ def f(part): assert 'foo' in str(info.value) -def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(): +@pytest.mark.parametrize( + "classes", + [ + (lgb.DaskLGBMClassifier, lgb.LGBMClassifier), + (lgb.DaskLGBMRegressor, lgb.LGBMRegressor), + (lgb.DaskLGBMRanker, lgb.LGBMRanker) + ] +) +def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes): def _compare_spec(dask_cls, sklearn_cls): dask_spec = inspect.getfullargspec(dask_cls) sklearn_spec = inspect.getfullargspec(sklearn_cls) @@ -789,6 +846,4 @@ def _compare_spec(dask_cls, sklearn_cls): assert dask_spec.args[-1] == 'client' assert dask_spec.defaults[-1] is None - _compare_spec(lgb.DaskLGBMClassifier, lgb.LGBMClassifier) - _compare_spec(lgb.DaskLGBMRegressor, lgb.LGBMRegressor) - _compare_spec(lgb.DaskLGBMRanker, lgb.LGBMRanker) + _compare_spec(classes[0], classes[1]) From dd3cfe02d749a42a862ae17346bb606aeaa56757 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 2 Feb 2021 16:41:14 -0600 Subject: [PATCH 17/18] linting --- python-package/lightgbm/dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 0ba89df2cee0..b1a7e2d80d64 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -28,7 +28,7 @@ def _get_dask_client(client: Optional[Client]) -> Client: - """Choose a Dask client to use + """Choose a Dask client to use. 
Parameters ---------- From 876bfe5f6c756e946594e3316fec716ae358de18 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 2 Feb 2021 22:05:31 -0600 Subject: [PATCH 18/18] simplify --- python-package/lightgbm/dask.py | 3 --- tests/python_package_test/test_dask.py | 27 ++++++++++++-------------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index b1a7e2d80d64..d8945fa5fa38 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -453,9 +453,6 @@ def _predict( class _DaskLGBMModel: - # self._client is set in the constructor of classes that use this mixin - _client: Optional[Client] = None - @property def client_(self) -> Client: """Dask client. diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index e004c9cf933d..cc9fa3adb184 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -832,18 +832,15 @@ def f(part): ] ) def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes): - def _compare_spec(dask_cls, sklearn_cls): - dask_spec = inspect.getfullargspec(dask_cls) - sklearn_spec = inspect.getfullargspec(sklearn_cls) - assert dask_spec.varargs == sklearn_spec.varargs - assert dask_spec.varkw == sklearn_spec.varkw - assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs - assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults - - # "client" should be the only different, and the final argument - assert dask_spec.args[:-1] == sklearn_spec.args - assert dask_spec.defaults[:-1] == sklearn_spec.defaults - assert dask_spec.args[-1] == 'client' - assert dask_spec.defaults[-1] is None - - _compare_spec(classes[0], classes[1]) + dask_spec = inspect.getfullargspec(classes[0]) + sklearn_spec = inspect.getfullargspec(classes[1]) + assert dask_spec.varargs == sklearn_spec.varargs + assert dask_spec.varkw == sklearn_spec.varkw + assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs + assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults + + # "client" should be the only different, and the final argument + assert dask_spec.args[:-1] == sklearn_spec.args + assert dask_spec.defaults[:-1] == sklearn_spec.defaults + assert dask_spec.args[-1] == 'client' + assert dask_spec.defaults[-1] is None
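
Taken together, the patches above settle on a single contract for client handling in the Dask estimators: ``client`` is an ordinary constructor parameter, the ``client_`` property resolves to ``distributed.default_client()`` when no client was given and is only readable on a fitted model, and the client is always dropped when the estimator is pickled. The sketch below illustrates that contract. It is not part of the patch series itself; it assumes a LightGBM build that already contains these changes plus ``dask[distributed]`` installed, and the toy dataset, parameter values (``n_estimators=5``, ``num_leaves=7``) and cluster sizing are made up purely for illustration.

    import pickle

    import dask.array as da
    import lightgbm as lgb
    from dask.distributed import Client, LocalCluster

    if __name__ == "__main__":
        # small local cluster, sized arbitrarily for the example
        cluster = LocalCluster(n_workers=2, threads_per_worker=1)
        client = Client(cluster)

        # toy regression data as Dask arrays
        X = da.random.random((1_000, 10), chunks=(100, 10))
        y = da.random.random((1_000,), chunks=(100,))

        # "client" can be passed to the constructor like any other parameter...
        model = lgb.DaskLGBMRegressor(n_estimators=5, num_leaves=7, client=client)
        model.fit(X, y)
        assert model.client_ is client  # client_ is only accessible once the model is fitted

        # ...or omitted, in which case distributed.default_client() is looked up at run time
        model_default = lgb.DaskLGBMRegressor(n_estimators=5, num_leaves=7)
        model_default.fit(X, y)
        assert model_default.client is None
        assert model_default.client_ is client  # falls back to the default client

        # pickling drops the client; the unpickled model has client=None
        reloaded = pickle.loads(pickle.dumps(model))
        assert reloaded.client is None
        preds = reloaded.predict(X).compute()

        client.close()
        cluster.close()

Dropping the client during pickling is what makes the last step work: a fitted model can be serialized on one cluster and, after unpickling somewhere else, re-attaches to whatever ``default_client()`` exists in that process, or to a client assigned explicitly with ``set_params(client=...)``, exactly as exercised by the two-cluster test above.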