Skip to content

Commit

Permalink
add GradientBoosting 'estimators'
Browse files Browse the repository at this point in the history
  • Loading branch information
thierrymoudiki committed Aug 3, 2024
1 parent 975b22f commit 26aa160
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 4 deletions.
9 changes: 8 additions & 1 deletion examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,30 @@
# Build one classifier per supported backend (the catboost variant is
# commented out in this example).
clf1 = ub.GBDTClassifier(model_type='xgboost')
#clf2 = ub.GBDTClassifier(model_type='catboost')
clf3 = ub.GBDTClassifier(model_type='lightgbm')
clf4 = ub.GBDTClassifier(model_type='gradientboosting',
                         colsample=0.9)

# Fit the model
clf1.fit(X_train, y_train)
#clf2.fit(X_train, y_train)
clf3.fit(X_train, y_train)
clf4.fit(X_train, y_train)

# Predict on the test set
y_pred1 = clf1.predict(X_test)
#y_pred2 = clf2.predict(X_test)
y_pred3 = clf3.predict(X_test)
y_pred4 = clf4.predict(X_test)

# Evaluate the model
accuracy1 = accuracy_score(y_test, y_pred1)
#accuracy2 = accuracy_score(y_test, y_pred2)
accuracy3 = accuracy_score(y_test, y_pred3)
accuracy4 = accuracy_score(y_test, y_pred4)
print(f"Classification Accuracy xgboost: {accuracy1:.2f}")
#print(f"Classification Accuracy catboost: {accuracy2:.2f}")
print(f"Classification Accuracy lightgbm: {accuracy3:.2f}")
print(f"Classification Accuracy gradientboosting: {accuracy4:.2f}")

# Cross-validated scores on the training split, one line per backend
# (each backend is evaluated exactly once).
print(f"CV xgboost: {cross_val_score(clf1, X_train, y_train)}")
print(f"CV lightgbm: {cross_val_score(clf3, X_train, y_train)}")
print(f"CV gradientboosting: {cross_val_score(clf4, X_train, y_train)}")
9 changes: 8 additions & 1 deletion examples/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,30 @@
# Build one regressor per supported backend (the catboost variant is
# commented out in this example). The class exposed by the package is
# GBDTRegressor.
regr1 = ub.GBDTRegressor(model_type='xgboost')
#regr2 = ub.GBDTRegressor(model_type='catboost')
regr3 = ub.GBDTRegressor(model_type='lightgbm')
regr4 = ub.GBDTRegressor(model_type='gradientboosting',
                         colsample=0.9)

# Fit the model
regr1.fit(X_train, y_train)
#regr2.fit(X_train, y_train)
regr3.fit(X_train, y_train)
regr4.fit(X_train, y_train)

# Predict on the test set
y_pred1 = regr1.predict(X_test)
#y_pred2 = regr2.predict(X_test)
y_pred3 = regr3.predict(X_test)
y_pred4 = regr4.predict(X_test)

# Evaluate the model
mse1 = mean_squared_error(y_test, y_pred1)
#mse2 = mean_squared_error(y_test, y_pred2)
mse3 = mean_squared_error(y_test, y_pred3)
mse4 = mean_squared_error(y_test, y_pred4)
print(f"Regression Mean Squared Error xgboost: {mse1:.2f}")
#print(f"Regression Mean Squared Error catboost: {mse2:.2f}")
print(f"Regression Mean Squared Error lightgbm: {mse3:.2f}")
print(f"Regression Mean Squared Error gradientboosting: {mse4:.2f}")

# Cross-validated scores on the training split, one line per backend
# (each backend is evaluated exactly once).
print(f"CV xgboost: {cross_val_score(regr1, X_train, y_train)}")
print(f"CV lightgbm: {cross_val_score(regr3, X_train, y_train)}")
print(f"CV gradientboosting: {cross_val_score(regr4, X_train, y_train)}")
17 changes: 17 additions & 0 deletions unifiedbooster/gbdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ class GBDT(BaseEstimator):
Attributes:
model_type: str
type of gradient boosting algorithm: 'xgboost', 'lightgbm',
'catboost', 'gradientboosting'
n_estimators: int
maximum number of trees that can be built
Expand Down Expand Up @@ -86,6 +90,17 @@ def __init__(
"bootstrap_type": "Bernoulli",
**kwargs,
}
elif self.model_type == "gradientboosting":
self.params = {
"n_estimators": self.n_estimators,
"learning_rate": self.learning_rate,
"subsample": self.rowsample,
"max_features": self.colsample,
"max_depth": self.max_depth,
"verbose": self.verbose,
"random_state": self.seed,
**kwargs,
}

def fit(self, X, y, **kwargs):
"""Fit custom model to training data (X, y).
Expand All @@ -112,6 +127,8 @@ def fit(self, X, y, **kwargs):
self.n_classes_ = len(
self.classes_
) # for compatibility with sklearn
if self.model_type == "gradientboosting":
    # scikit-learn accepts `max_features` either as an int (a feature
    # count) or as a float fraction in (0, 1]. Derive the count from the
    # configured column-sampling ratio instead of multiplying the model
    # attribute in place: the result is then a valid int, and calling
    # fit() repeatedly no longer compounds the scaling.
    # NOTE(review): this overrides any `max_features` passed through
    # **kwargs at construction time -- confirm that is acceptable.
    self.model.max_features = max(1, int(self.colsample * X.shape[1]))
return getattr(self, "model").fit(X, y, **kwargs)

def predict(self, X):
Expand Down
8 changes: 7 additions & 1 deletion unifiedbooster/gbdt_classification.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from .gbdt import GBDT
from sklearn.base import ClassifierMixin
from xgboost import XGBClassifier

# catboost is an optional backend: keep the package importable when it is
# missing or fails to load. `Exception` (rather than a bare `except`) lets
# KeyboardInterrupt and SystemExit propagate while still tolerating
# load-time failures that are not ImportError.
try:
    from catboost import CatBoostClassifier
except Exception:
    print("catboost package can't be built")
from lightgbm import LGBMClassifier
from sklearn.ensemble import GradientBoostingClassifier


class GBDTClassifier(GBDT, ClassifierMixin):
"""GBDT Classification model
Attributes:
model_type: str
type of gradient boosting algorithm: 'xgboost', 'lightgbm',
'catboost', 'gradientboosting'
n_estimators: int
maximum number of trees that can be built
Expand Down Expand Up @@ -108,6 +112,8 @@ def __init__(
self.model = CatBoostClassifier(**self.params)
elif model_type == "lightgbm":
self.model = LGBMClassifier(**self.params)
elif model_type == "gradientboosting":
self.model = GradientBoostingClassifier(**self.params)
else:
raise ValueError(f"Unknown model_type: {model_type}")

Expand Down
8 changes: 7 additions & 1 deletion unifiedbooster/gbdt_regression.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from .gbdt import GBDT
from sklearn.base import RegressorMixin
from xgboost import XGBRegressor

# catboost is an optional backend: keep the package importable when it is
# missing or fails to load. `Exception` (rather than a bare `except`) lets
# KeyboardInterrupt and SystemExit propagate while still tolerating
# load-time failures that are not ImportError.
try:
    from catboost import CatBoostRegressor
except Exception:
    print("catboost package can't be built")
from lightgbm import LGBMRegressor
from sklearn.ensemble import GradientBoostingRegressor


class GBDTRegressor(GBDT, RegressorMixin):
"""GBDT Regression model
Attributes:
model_type: str
type of gradient boosting algorithm: 'xgboost', 'lightgbm',
'catboost', 'gradientboosting'
n_estimators: int
maximum number of trees that can be built
Expand Down Expand Up @@ -108,5 +112,7 @@ def __init__(
self.model = CatBoostRegressor(**self.params)
elif model_type == "lightgbm":
self.model = LGBMRegressor(**self.params)
elif model_type == "gradientboosting":
self.model = GradientBoostingRegressor(**self.params)
else:
raise ValueError(f"Unknown model_type: {model_type}")

0 comments on commit 26aa160

Please sign in to comment.