[ENH] Add support for sklearn HistGradientBoostingEstimator #1230

Merged (13 commits) on Oct 31, 2023
1 change: 1 addition & 0 deletions flaml/automl/contrib/__init__.py
@@ -0,0 +1 @@
from .histgb import HistGradientBoostingEstimator
75 changes: 75 additions & 0 deletions flaml/automl/contrib/histgb.py
@@ -0,0 +1,75 @@
try:
from sklearn.ensemble import HistGradientBoostingClassifier, HistGradientBoostingRegressor
except ImportError:
pass

from flaml import tune
from flaml.automl.model import SKLearnEstimator
from flaml.automl.task import Task


class HistGradientBoostingEstimator(SKLearnEstimator):
"""The class for tuning Histogram Gradient Boosting."""

ITER_HP = "max_iter"
HAS_CALLBACK = False
DEFAULT_ITER = 100

@classmethod
def search_space(cls, data_size: int, task, **params) -> dict:
upper = max(5, min(32768, int(data_size[0]))) # upper must be larger than lower
return {
"n_estimators": {
"domain": tune.lograndint(lower=4, upper=upper),
"init_value": 4,
"low_cost_init_value": 4,
},
"max_leaves": {
"domain": tune.lograndint(lower=4, upper=upper),
"init_value": 4,
"low_cost_init_value": 4,
},
"min_samples_leaf": {
"domain": tune.lograndint(lower=2, upper=2**7 + 1),
"init_value": 20,
},
"learning_rate": {
"domain": tune.loguniform(lower=1 / 1024, upper=1.0),
"init_value": 0.1,
},
"log_max_bin": { # log transformed with base 2, <= 256
"domain": tune.lograndint(lower=3, upper=9),
"init_value": 8,
},
"l2_regularization": {
"domain": tune.loguniform(lower=1 / 1024, upper=1024),
"init_value": 1.0,
},
}

def config2params(self, config: dict) -> dict:
params = super().config2params(config)
if "log_max_bin" in params:
params["max_bins"] = (1 << params.pop("log_max_bin")) - 1
if "max_leaves" in params:
params["max_leaf_nodes"] = params.get("max_leaf_nodes", params.pop("max_leaves"))
if "n_estimators" in params:
params["max_iter"] = params.get("max_iter", params.pop("n_estimators"))
if "random_state" not in params:
params["random_state"] = 24092023
if "n_jobs" in params:
params.pop("n_jobs")
return params

def __init__(
self,
task: Task,
**config,
):
super().__init__(task, **config)
self.params["verbose"] = 0

if self._task.is_classification():
self.estimator_class = HistGradientBoostingClassifier
else:
self.estimator_class = HistGradientBoostingRegressor
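For reference, a minimal sketch (not part of the PR) of what `config2params` above produces for a sampled configuration and how the result maps onto scikit-learn's native constructor arguments; the concrete values are illustrative only:

```python
# Illustrative mapping of a tuned FLAML config to native scikit-learn
# HistGradientBoosting* constructor arguments, mirroring config2params above.
from sklearn.ensemble import HistGradientBoostingRegressor

config = {
    "n_estimators": 512,      # renamed to max_iter
    "max_leaves": 32,         # renamed to max_leaf_nodes
    "min_samples_leaf": 20,
    "learning_rate": 0.1,
    "log_max_bin": 8,         # becomes max_bins = 2**8 - 1 = 255
    "l2_regularization": 1.0,
}

params = dict(config)
params["max_bins"] = (1 << params.pop("log_max_bin")) - 1  # 255
params["max_leaf_nodes"] = params.pop("max_leaves")        # 32
params["max_iter"] = params.pop("n_estimators")            # 512
params.setdefault("random_state", 24092023)                # fixed seed used by the estimator

model = HistGradientBoostingRegressor(**params)  # only native sklearn params remain
```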
52 changes: 30 additions & 22 deletions flaml/automl/model.py
@@ -2,36 +2,42 @@
# * Copyright (c) FLAML authors. All rights reserved.
# * Licensed under the MIT License. See LICENSE file in the
# * project root for license information.
import logging
import math
import os
import shutil
import signal
import sys
import time
from contextlib import contextmanager
from functools import partial
import signal
import os
from typing import Callable, List, Union

import numpy as np
import time
import logging
import shutil
import sys
import math

from flaml import tune
from flaml.automl.data import (
group_counts,
)
from flaml.automl.task.factory import task_factory
from flaml.automl.task.task import (
Task,
NLG_TASKS,
SEQCLASSIFICATION,
SEQREGRESSION,
TOKENCLASSIFICATION,
SUMMARIZATION,
NLG_TASKS,
TOKENCLASSIFICATION,
Task,
)
from flaml.automl.task.factory import task_factory

try:
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.ensemble import ExtraTreesRegressor, ExtraTreesClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.ensemble import (
ExtraTreesClassifier,
ExtraTreesRegressor,
RandomForestClassifier,
RandomForestRegressor,
)
from sklearn.linear_model import LogisticRegression
from xgboost import __version__ as xgboost_version
except ImportError:
pass
@@ -41,13 +47,14 @@
except ImportError:
pass

from flaml.automl.spark import psDataFrame, sparkDataFrame, psSeries, ERROR as SPARK_ERROR, DataFrame, Series
from flaml.automl.spark.utils import len_labels, to_pandas_on_spark
from flaml.automl.spark import ERROR as SPARK_ERROR
from flaml.automl.spark import DataFrame, Series, psDataFrame, psSeries, sparkDataFrame
from flaml.automl.spark.configs import (
ParamList_LightGBM_Classifier,
ParamList_LightGBM_Regressor,
ParamList_LightGBM_Ranker,
ParamList_LightGBM_Regressor,
)
from flaml.automl.spark.utils import len_labels, to_pandas_on_spark

if DataFrame is not None:
from pandas import to_datetime
@@ -62,7 +69,7 @@
resource = None

try:
from lightgbm import LGBMClassifier, LGBMRegressor, LGBMRanker
from lightgbm import LGBMClassifier, LGBMRanker, LGBMRegressor
except ImportError:
LGBMClassifier = LGBMRegressor = LGBMRanker = None

@@ -320,8 +327,7 @@ def score(self, X_val: DataFrame, y_val: Series, **kwargs):
Returns:
The evaluation score on the validation dataset.
"""
from .ml import metric_loss_score
from .ml import is_min_metric
from .ml import is_min_metric, metric_loss_score

if self._model is not None:
if self._task == "rank":
@@ -759,7 +765,7 @@ def no_cuda(self):
return not self._kwargs.get("gpu_per_trial")

def _set_training_args(self, **kwargs):
from .nlp.utils import date_str, Counter
from .nlp.utils import Counter, date_str

for key, val in kwargs.items():
assert key not in self.params, (
@@ -873,10 +879,10 @@ def tokenizer(self):

@property
def data_collator(self):
from flaml.automl.task.task import Task
from flaml.automl.nlp.huggingface.data_collator import (
task_to_datacollator_class,
)
from flaml.automl.task.task import Task

data_collator_class = task_to_datacollator_class.get(
self._task.name if isinstance(self._task, Task) else self._task
@@ -917,6 +923,7 @@ def fit(

from transformers import TrainerCallback
from transformers.trainer_utils import set_seed

from .nlp.huggingface.trainer import TrainerForAuto

try:
@@ -1146,6 +1153,7 @@ def score(self, X_val: DataFrame, y_val: Series, **kwargs):
def predict(self, X, **pred_kwargs):
import transformers
from datasets import Dataset

from .nlp.huggingface.utils import postprocess_prediction_and_true

transformers.logging.set_verbosity_error()
51 changes: 27 additions & 24 deletions flaml/automl/task/generic_task.py
@@ -1,43 +1,44 @@
import logging
import time
from typing import List, Optional

import numpy as np
from flaml.automl.data import TS_TIMESTAMP_COL, concat
from flaml.automl.ml import EstimatorSubclass, get_val_loss, default_cv_score_agg_func

from flaml.automl.task.task import (
Task,
get_classification_objective,
TS_FORECAST,
TS_FORECASTPANEL,
)
from flaml.config import RANDOM_SEED
from flaml.automl.spark import ps, psDataFrame, psSeries, pd
from flaml.automl.data import TS_TIMESTAMP_COL, concat
from flaml.automl.ml import EstimatorSubclass, default_cv_score_agg_func, get_val_loss
from flaml.automl.spark import pd, ps, psDataFrame, psSeries
from flaml.automl.spark.utils import (
iloc_pandas_on_spark,
len_labels,
set_option,
spark_kFold,
train_test_split_pyspark,
unique_pandas_on_spark,
unique_value_first_index,
len_labels,
set_option,
)
from flaml.automl.task.task import (
TS_FORECAST,
TS_FORECASTPANEL,
Task,
get_classification_objective,
)
from flaml.config import RANDOM_SEED

try:
from scipy.sparse import issparse
except ImportError:
pass
try:
from sklearn.utils import shuffle
from sklearn.model_selection import (
train_test_split,
RepeatedStratifiedKFold,
RepeatedKFold,
GroupKFold,
TimeSeriesSplit,
GroupShuffleSplit,
RepeatedKFold,
RepeatedStratifiedKFold,
StratifiedGroupKFold,
TimeSeriesSplit,
train_test_split,
)
from sklearn.utils import shuffle
except ImportError:
pass

@@ -49,19 +50,20 @@ class GenericTask(Task):
def estimators(self):
if self._estimators is None:
# put this into a function to avoid circular dependency
from flaml.automl.contrib.histgb import HistGradientBoostingEstimator
from flaml.automl.model import (
XGBoostSklearnEstimator,
XGBoostLimitDepthEstimator,
RandomForestEstimator,
LGBMEstimator,
LRL1Classifier,
LRL2Classifier,
CatBoostEstimator,
ExtraTreesEstimator,
KNeighborsEstimator,
LGBMEstimator,
LRL1Classifier,
LRL2Classifier,
RandomForestEstimator,
SparkLGBMEstimator,
TransformersEstimator,
TransformersEstimatorModelSelection,
SparkLGBMEstimator,
XGBoostLimitDepthEstimator,
XGBoostSklearnEstimator,
)

self._estimators = {
@@ -77,6 +79,7 @@ def estimators(self):
"kneighbor": KNeighborsEstimator,
"transformer": TransformersEstimator,
"transformer_ms": TransformersEstimatorModelSelection,
"histgb": HistGradientBoostingEstimator,
}
return self._estimators

25 changes: 21 additions & 4 deletions test/automl/test_classification.py
@@ -1,13 +1,14 @@
import unittest
from datetime import datetime

import numpy as np
import pandas as pd
import scipy.sparse
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import pandas as pd
from datetime import datetime
from flaml import AutoML

from flaml import AutoML, tune
from flaml.automl.model import LGBMEstimator
from flaml import tune


class MyLargeLGBM(LGBMEstimator):
@@ -194,6 +195,22 @@ def test_preprocess(self):
automl.fit(X, y, **automl_settings)
del automl

automl = AutoML()
automl_settings = {
"time_budget": 3,
"task": "classification",
"n_jobs": 1,
"estimator_list": ["histgb"],
"eval_method": "cv",
"n_splits": 3,
"metric": "accuracy",
"log_training_metric": True,
# "verbose": 4,
"ensemble": True,
}
automl.fit(X, y, **automl_settings)
del automl

def test_binary(self):
automl_experiment = AutoML()
automl_settings = {
22 changes: 15 additions & 7 deletions test/test_model.py
@@ -1,17 +1,20 @@
from sklearn.datasets import make_classification
from datetime import datetime

import numpy as np
from pandas import DataFrame
from datetime import datetime
from sklearn.datasets import make_classification

from flaml.automl.contrib.histgb import HistGradientBoostingEstimator
from flaml.automl.model import (
KNeighborsEstimator,
LRL2Classifier,
BaseEstimator,
LGBMEstimator,
CatBoostEstimator,
XGBoostEstimator,
KNeighborsEstimator,
LGBMEstimator,
LRL2Classifier,
RandomForestEstimator,
XGBoostEstimator,
)
from flaml.automl.time_series import Prophet, ARIMA, LGBM_TS, TimeSeriesDataset
from flaml.automl.time_series import ARIMA, LGBM_TS, Prophet, TimeSeriesDataset


def test_lrl2():
@@ -90,6 +93,11 @@ def test_prep():
rf.fit(X, y)
print(rf.feature_names_in_)
print(rf.feature_importances_)
hgb = HistGradientBoostingEstimator(task="regression", n_estimators=4, max_leaves=4)
hgb.fit(X, y)
hgb.predict(X)
print(hgb.feature_names_in_)
print(hgb.feature_importances_)

prophet = Prophet()
try:
1 change: 1 addition & 0 deletions website/docs/Use-Cases/Task-Oriented-AutoML.md
@@ -118,6 +118,7 @@ The estimator list can contain one or more estimator names, each corresponding to
it uses a fixed random_state by default.
- 'extra_tree': ExtraTreesEstimator for task "classification", "regression", "ts_forecast" and "ts_forecast_classification". Hyperparameters: n_estimators, max_features, max_leaves, criterion (for classification only). Starting from v1.1.0,
it uses a fixed random_state by default.
- 'histgb': HistGradientBoostingEstimator for task "classification", "regression", "ts_forecast" and "ts_forecast_classification". Hyperparameters: n_estimators, max_leaves, min_samples_leaf, learning_rate, log_max_bin (logarithm of (max_bin + 1) with base 2), l2_regularization. It uses a fixed random_state by default.
- 'lrl1': LRL1Classifier (sklearn.LogisticRegression with L1 regularization) for task "classification". Hyperparameters: C.
- 'lrl2': LRL2Classifier (sklearn.LogisticRegression with L2 regularization) for task "classification". Hyperparameters: C.
- 'catboost': CatBoostEstimator for task "classification" and "regression". Hyperparameters: early_stopping_rounds, learning_rate, n_estimators.
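With the documentation entry above, the new 'histgb' estimator is selected like any other built-in estimator. A minimal usage sketch (dataset, metric, and time budget are illustrative, mirroring the added test):

```python
from sklearn.datasets import load_breast_cancer

from flaml import AutoML

X, y = load_breast_cancer(return_X_y=True)

automl = AutoML()
automl.fit(
    X, y,
    task="classification",
    metric="accuracy",
    estimator_list=["histgb"],  # the estimator registered by this PR
    time_budget=3,              # seconds; illustrative only
)
print(automl.best_config)
```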