Skip to content

Commit

Permalink
Flaml: fix lgbm reproducibility (#1369)
Browse files Browse the repository at this point in the history
* fix: Fixed bug where every underlying LGBMRegressor or LGBMClassifier had n_estimators = 1

* test: Added test showing case where FLAMLised CatBoostModel result isn't reproducible

* fix: Fixing issue where callbacks cause LGBM results to not be reproducible

* Update test/automl/test_regression.py

Co-authored-by: Li Jiang <[email protected]>

* fix: Adding back the LGBM EarlyStopping

* refactor: Fix tweaked to ensure other models aren't likely to be affected

* test: Fixed test to allow reproduced results to be better than the FLAML results, when LGBM earlystopping is involved

---------

Co-authored-by: Daniel Grindrod <[email protected]>
Co-authored-by: Li Jiang <[email protected]>
  • Loading branch information
3 people authored Nov 1, 2024
1 parent 7644958 commit 5a74227
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 6 deletions.
7 changes: 3 additions & 4 deletions flaml/automl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,18 +1585,17 @@ def fit(self, X_train, y_train, budget=None, free_mem_ratio=0, **kwargs):
callbacks = None
if callbacks is None:
self._fit(X_train, y_train, **kwargs)
else:
self._fit(X_train, y_train, callbacks=callbacks, **kwargs)
if callbacks is None:
# for xgboost>=1.6.0, pop callbacks to enable pickle
callbacks = self.params.pop("callbacks")
self._model.set_params(callbacks=callbacks[:-1])
else:
self._fit(X_train, y_train, callbacks=callbacks, **kwargs)
best_iteration = (
getattr(self._model.get_booster(), "best_iteration", None)
if isinstance(self, XGBoostSklearnEstimator)
else self._model.best_iteration_
)
if best_iteration is not None:
if best_iteration is not None and best_iteration > 0:
self._model.set_params(n_estimators=best_iteration + 1)
else:
self._fit(X_train, y_train, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion test/automl/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def test_reproducibility_of_classification_models(estimator: str):
"extra_tree",
"histgb",
"kneighbor",
# "lgbm",
"lgbm",
# "lrl1",
"lrl2",
"svc",
Expand Down
49 changes: 48 additions & 1 deletion test/automl/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,52 @@ def test_reproducibility_of_catboost_regression_model():
assert pytest.approx(val_loss_flaml) == reproduced_val_loss


def test_reproducibility_of_lgbm_regression_model():
"""FLAML finds the best model for a given dataset, which it then provides to users.
However, there are reported issues around LGBMs - see here:
https://github.com/microsoft/FLAML/issues/1368
In this test we take the best LGBM regression model which FLAML provided us, and then retrain and test it on the
same folds, to verify that the result is reproducible.
"""
automl = AutoML()
automl_settings = {
"time_budget": 3,
"task": "regression",
"n_jobs": 1,
"estimator_list": ["lgbm"],
"eval_method": "cv",
"n_splits": 9,
"metric": "r2",
"keep_search_state": True,
"skip_transform": True,
"retrain_full": True,
}
X, y = fetch_california_housing(return_X_y=True, as_frame=True)
automl.fit(X_train=X, y_train=y, **automl_settings)
best_model = automl.model
assert best_model is not None
config = best_model.get_params()
val_loss_flaml = automl.best_result["val_loss"]

# Take the best model, and see if we can reproduce the best result
reproduced_val_loss, metric_for_logging, train_time, pred_time = automl._state.task.evaluate_model_CV(
config=config,
estimator=best_model,
X_train_all=automl._state.X_train_all,
y_train_all=automl._state.y_train_all,
budget=None,
kf=automl._state.kf,
eval_metric="r2",
best_val_loss=None,
cv_score_agg_func=None,
log_training_metric=False,
fit_kwargs=None,
free_mem_ratio=0,
)
assert pytest.approx(val_loss_flaml) == reproduced_val_loss or val_loss_flaml > reproduced_val_loss


@pytest.mark.parametrize(
"estimator",
[
Expand All @@ -347,7 +393,7 @@ def test_reproducibility_of_catboost_regression_model():
"extra_tree",
"histgb",
"kneighbor",
# "lgbm",
"lgbm",
"rf",
"xgboost",
"xgb_limitdepth",
Expand Down Expand Up @@ -376,6 +422,7 @@ def test_reproducibility_of_underlying_regression_models(estimator: str):
"metric": "r2",
"keep_search_state": True,
"skip_transform": True,
"retrain_full": False,
}
X, y = fetch_california_housing(return_X_y=True, as_frame=True)
automl.fit(X_train=X, y_train=y, **automl_settings)
Expand Down

0 comments on commit 5a74227

Please sign in to comment.