Skip to content

Commit

Permalink
Add support for multiclass XGBoost RF (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS authored May 20, 2020
1 parent 757585d commit 4869e00
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pip install m2cgen
| **Linear** | <ul><li>scikit-learn<ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul></li><li>lightning<ul><li>AdaGradClassifier</li><li>CDClassifier</li><li>FistaClassifier</li><li>SAGAClassifier</li><li>SAGClassifier</li><li>SDCAClassifier</li><li>SGDClassifier</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul><li>StatsModels<ul><li>Generalized Least Squares (GLS)</li><li>Generalized Least Squares with AR Errors (GLSAR)</li><li>Generalized Linear Models (GLM)</li><li>Ordinary Least Squares (OLS)</li><li>[Gaussian] Process Regression Using Maximum Likelihood-based Estimation (ProcessMLE)</li><li>Quantile Regression (QuantReg)</li><li>Weighted Least Squares (WLS)</li></ul><li>lightning<ul><li>AdaGradRegressor</li><li>CDRegressor</li><li>FistaRegressor</li><li>SAGARegressor</li><li>SAGRegressor</li><li>SDCARegressor</li></ul></li></ul> |
| **SVM** | <ul><li>scikit-learn<ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul></li><li>lightning<ul><li>KernelSVC</li><li>LinearSVC</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul></li><li>lightning<ul><li>LinearSVR</li></ul></li></ul> |
| **Tree** | <ul><li>DecisionTreeClassifier</li><li>ExtraTreeClassifier</li></ul> | <ul><li>DecisionTreeRegressor</li><li>ExtraTreeRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier(binary only, multiclass is not supported yet)</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
| **Boosting** | <ul><li>LGBMClassifier(gbdt/dart/goss booster only)</li><li>XGBClassifier(gbtree/gblinear booster only)</li><ul> | <ul><li>LGBMRegressor(gbdt/dart/goss booster only)</li><li>XGBRegressor(gbtree/gblinear booster only)</li></ul> |

You can find versions of packages with which compatibility is guaranteed by CI tests [here](https://github.com/BayesWitnesses/m2cgen/blob/master/requirements-test.txt#L1).
Expand Down
23 changes: 15 additions & 8 deletions m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class BaseBoostingAssembler(ModelAssembler):

classifier_names = {}
strided_layout_for_multiclass = True

def __init__(self, model, estimator_params, base_score=0):
super().__init__(model)
Expand Down Expand Up @@ -54,7 +55,8 @@ def _assemble_multi_class_output(self, estimator_params):
# Multi-class output is calculated based on discussion in
# https://github.com/dmlc/xgboost/issues/1746#issuecomment-295962863
splits = _split_estimator_params_by_classes(
estimator_params, self._output_size)
estimator_params, self._output_size,
self.strided_layout_for_multiclass)

base_score = self._base_score
exprs = [
Expand Down Expand Up @@ -112,9 +114,8 @@ class XGBoostTreeModelAssembler(BaseTreeBoostingAssembler):
classifier_names = {"XGBClassifier", "XGBRFClassifier"}

def __init__(self, model):
if type(model).__name__ == "XGBRFClassifier" and model.n_classes_ > 2:
raise RuntimeError(
"Multiclass XGBRFClassifier is not supported yet")
if type(model).__name__ == "XGBRFClassifier":
self.strided_layout_for_multiclass = False
feature_names = model.get_booster().feature_names
self._feature_name_to_idx = {
name: idx for idx, name in enumerate(feature_names or [])
Expand Down Expand Up @@ -243,7 +244,13 @@ def _assemble_tree(self, tree):
self._assemble_tree(false_child))


def _split_estimator_params_by_classes(values, n_classes):
# Splits are computed based on a comment
# https://github.com/dmlc/xgboost/issues/1746#issuecomment-267400592.
return [values[class_idx::n_classes] for class_idx in range(n_classes)]
def _split_estimator_params_by_classes(values, n_classes, strided):
if strided:
# Splits are computed based on a comment
# https://github.com/dmlc/xgboost/issues/1746#issuecomment-267400592.
return [values[class_idx::n_classes] for class_idx in range(n_classes)]
else:
values_len = len(values)
block_len = values_len // n_classes
return [values[start_block_idx:start_block_idx + block_len]
for start_block_idx in range(0, values_len, block_len)]
1 change: 1 addition & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def regression_bounded(model, test_fraction=0.02):
# XGBoost (RF)
regression(xgboost.XGBRFRegressor(**XGBOOST_PARAMS_RF)),
classification(xgboost.XGBRFClassifier(**XGBOOST_PARAMS_RF)),
classification_binary(xgboost.XGBRFClassifier(**XGBOOST_PARAMS_RF)),
# XGBoost (Large Trees)
Expand Down

0 comments on commit 4869e00

Please sign in to comment.