Skip to content

Commit

Permalink
fix: Fixing the random state of ElasticNetClassifier by default, to e…
Browse files Browse the repository at this point in the history
…nsure reproduciblity. Also included elasticnet in reproducibility tests (#1374)

Co-authored-by: Daniel Grindrod <[email protected]>
Co-authored-by: Li Jiang <[email protected]>
  • Loading branch information
3 people authored Oct 29, 2024
1 parent 69da685 commit 72881d3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions flaml/automl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,6 +2429,11 @@ def config2params(self, config: dict) -> dict:

def __init__(self, task="regression", **config):
super().__init__(task, **config)
self.params.update(
{
"random_state": config.get("random_seed", 10242048),
}
)
assert self._task.is_regression(), "ElasticNet for regression task only"
self.estimator_class = ElasticNet

Expand Down
3 changes: 2 additions & 1 deletion test/automl/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def test_multioutput():
"estimator",
[
"catboost",
"enet",
"extra_tree",
"histgb",
"kneighbor",
Expand Down Expand Up @@ -342,6 +343,7 @@ def test_reproducibility_of_catboost_regression_model():
"estimator",
[
"catboost",
"enet",
"extra_tree",
"histgb",
"kneighbor",
Expand Down Expand Up @@ -385,7 +387,6 @@ def test_reproducibility_of_underlying_regression_models(estimator: str):
automl._state.X_train_all, automl._state.y_train_all, automl._state.kf, best_model.model, "regression"
)
)

assert pytest.approx(val_loss_flaml) == reproduced_val_loss_underlying_model


Expand Down

0 comments on commit 72881d3

Please sign in to comment.