From 1e175c9f2ed53036eab6dd0d8843202c6f63946c Mon Sep 17 00:00:00 2001
From: Jovan Veljanoski
Date: Wed, 7 Feb 2024 16:38:03 +0100
Subject: [PATCH] fix[ml]: adjust tests to reflect latest apis of 3rd party
 libraries (xgboost, lightgbm)

---
 ci/conda-env.yml                     |  2 +-
 packages/vaex-ml/vaex/ml/lightgbm.py | 13 +++++++++----
 tests/ml/xgboost_test.py             |  5 +----
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/ci/conda-env.yml b/ci/conda-env.yml
index 635349147b..34138c94d7 100644
--- a/ci/conda-env.yml
+++ b/ci/conda-env.yml
@@ -20,7 +20,7 @@ dependencies:
 - h5py
 - httpx # for testing with starlette/fastapi
 - ipyvolume=0.6.0a6
-- lightgbm
+- lightgbm>=4.0.0
 - matplotlib-base
 - nest-asyncio<1.5.2
 - notebook
diff --git a/packages/vaex-ml/vaex/ml/lightgbm.py b/packages/vaex-ml/vaex/ml/lightgbm.py
index 5fb83ea458..4dbec9d42f 100644
--- a/packages/vaex-ml/vaex/ml/lightgbm.py
+++ b/packages/vaex-ml/vaex/ml/lightgbm.py
@@ -86,7 +86,7 @@ def transform(self, df):
         copy.add_virtual_column(self.prediction_name, expression, unique=False)
         return copy
 
-    def fit(self, df, valid_sets=None, valid_names=None, early_stopping_rounds=None, evals_result=None, verbose_eval=None, **kwargs):
+    def fit(self, df, valid_sets=None, valid_names=None, early_stopping_rounds=None, evals_result=None, verbose_eval=False, **kwargs):
         """Fit the LightGBMModel to the DataFrame.
 
         The model will train until the validation score stops improving.
@@ -112,14 +112,19 @@ def fit(self, df, valid_sets=None, valid_names=None, early_stopping_rounds=None,
         else:
             valid_sets = ()
 
+        callbacks = [
+            lightgbm.callback.record_evaluation(eval_result=evals_result) if evals_result is not None else None,
+            lightgbm.callback.early_stopping(stopping_rounds=early_stopping_rounds) if early_stopping_rounds else None,
+            lightgbm.callback.log_evaluation() if verbose_eval else None
+        ]
+        callbacks = [callback for callback in callbacks if callback is not None]
+
         self.booster = lightgbm.train(params=self.params,
                                       train_set=dtrain,
                                       num_boost_round=self.num_boost_round,
                                       valid_sets=valid_sets,
                                       valid_names=valid_names,
-                                      early_stopping_rounds=early_stopping_rounds,
-                                      evals_result=evals_result,
-                                      verbose_eval=verbose_eval,
+                                      callbacks=callbacks,
                                       **kwargs)
 
     def predict(self, df, **kwargs):
diff --git a/tests/ml/xgboost_test.py b/tests/ml/xgboost_test.py
index 9a1a5dd371..c9c264db98 100644
--- a/tests/ml/xgboost_test.py
+++ b/tests/ml/xgboost_test.py
@@ -17,7 +17,6 @@
     'objective': 'multi:softmax',  # learning task objective
     'num_class': 3,                # number of target classes (if classification)
     'random_state': 42,            # fixes the seed, for reproducibility
-    'silent': 1,                   # silent mode
     'n_jobs': -1                   # cpu cores used
 }
 
@@ -32,14 +31,13 @@
     'min_child_weight': 1,         # minimum sum of instance weight (hessian) needed in a child
     'objective': 'reg:linear',     # learning task objective
     'random_state': 42,            # fixes the seed, for reproducibility
-    'silent': 1,                   # silent mode
     'n_jobs': -1                   # cpu cores used
 }
 
 
 def test_xgboost(df_iris):
     ds = df_iris
-    ds_train, ds_test = ds.ml.train_test_split(test_size=0.2, verbose=False)
+    ds_train, ds_test = ds.ml.train_test_split(test_size=0.1, verbose=False)
     features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
     booster = vaex.ml.xgboost.XGBoostModel(num_boost_round=10,
                                            params=params_multiclass,
@@ -104,7 +102,6 @@ def test_xgboost_validation_set(df_example):
     # fit the booster - including saving the history of the validation sets
     booster.fit(train, evals=[(train, 'train'), (test, 'test')],
                 early_stopping_rounds=2, evals_result=history)
-    assert booster.booster.best_ntree_limit == 10
     assert booster.booster.best_iteration == 9
    assert len(history['train']['rmse']) == 10
     assert len(history['test']['rmse']) == 10
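
Note for reviewers: the substantive change is the move away from the
early_stopping_rounds=, evals_result= and verbose_eval= keyword arguments of
lightgbm.train(), which LightGBM 4.x removed, to the equivalent callbacks.
The standalone sketch below shows the same migration outside of vaex; the
synthetic data, parameter values and the 'validation' name are illustrative
assumptions, not part of this patch.

# A minimal sketch of the LightGBM >= 4.0 callbacks API, assuming a synthetic
# regression dataset; only the callback usage mirrors the patched fit() above.
import numpy as np
import lightgbm

rng = np.random.default_rng(42)
X = rng.normal(size=(500, 4))
y = X @ np.array([0.5, -1.0, 2.0, 0.25]) + rng.normal(scale=0.1, size=500)

dtrain = lightgbm.Dataset(X[:400], label=y[:400])
dvalid = lightgbm.Dataset(X[400:], label=y[400:], reference=dtrain)

history = {}  # filled in place by record_evaluation, replacing evals_result=
booster = lightgbm.train(
    params={'objective': 'regression', 'verbosity': -1},
    train_set=dtrain,
    num_boost_round=100,
    valid_sets=[dvalid],
    valid_names=['validation'],
    callbacks=[
        lightgbm.callback.early_stopping(stopping_rounds=5),       # was early_stopping_rounds=5
        lightgbm.callback.record_evaluation(eval_result=history),  # was evals_result=history
        lightgbm.callback.log_evaluation(period=10),               # was verbose_eval=10
    ],
)
print(booster.best_iteration, len(history['validation']['l2']))

The xgboost side of the patch follows the same pattern of tracking upstream
API removals: the 'silent' parameter is gone from current xgboost releases
(its replacement is 'verbosity'), and best_ntree_limit was removed in
xgboost 2.0 in favour of best_iteration, which the test already checks.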