From e13bea39e1f09382060ebb401175e4510efeb8aa Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Sun, 29 Mar 2020 17:06:31 -0700 Subject: [PATCH] Fix #168. Enforce float32 type for split condition values for GBT models created using XGBoost --- m2cgen/assemblers/boosting.py | 2 +- tests/e2e/test_e2e.py | 35 +++++++++++++++++++++++------------ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/m2cgen/assemblers/boosting.py b/m2cgen/assemblers/boosting.py index e327b774..07e250e6 100644 --- a/m2cgen/assemblers/boosting.py +++ b/m2cgen/assemblers/boosting.py @@ -134,7 +134,7 @@ def _assemble_tree(self, tree): if "leaf" in tree: return ast.NumVal(tree["leaf"]) - threshold = ast.NumVal(tree["split_condition"]) + threshold = ast.NumVal(np.float32(tree["split_condition"])) split = tree["split"] feature_idx = self._feature_name_to_idx.get(split, split) feature_ref = ast.FeatureRef(feature_idx) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index b92185f8..8b6e8baa 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -35,50 +35,51 @@ # Set of helper functions to make parametrization less verbose. -def regression(model): +def regression(model, test_fraction=0.02): return ( model, - utils.get_regression_model_trainer(), + utils.get_regression_model_trainer(test_fraction), REGRESSION, ) -def classification(model): +def classification(model, test_fraction=0.02): return ( model, - utils.get_classification_model_trainer(), + utils.get_classification_model_trainer(test_fraction), CLASSIFICATION, ) -def classification_binary(model): +def classification_binary(model, test_fraction=0.02): return ( model, - utils.get_binary_classification_model_trainer(), + utils.get_binary_classification_model_trainer(test_fraction), CLASSIFICATION, ) -def regression_random(model): +def regression_random(model, test_fraction=0.02): return ( model, - utils.get_regression_random_data_model_trainer(0.01), + utils.get_regression_random_data_model_trainer(test_fraction), REGRESSION, ) -def classification_random(model): +def classification_random(model, test_fraction=0.02): return ( model, - utils.get_classification_random_data_model_trainer(0.01), + utils.get_classification_random_data_model_trainer(test_fraction), CLASSIFICATION, ) -def classification_binary_random(model): +def classification_binary_random(model, test_fraction=0.02): return ( model, - utils.get_classification_binary_random_data_model_trainer(0.01), + utils.get_classification_binary_random_data_model_trainer( + test_fraction), CLASSIFICATION, ) @@ -92,6 +93,8 @@ def classification_binary_random(model): FOREST_PARAMS = dict(n_estimators=10, random_state=RANDOM_SEED) XGBOOST_PARAMS = dict(base_score=0.6, n_estimators=10, random_state=RANDOM_SEED) +XGBOOST_HIST_PARAMS = dict(base_score=0.6, n_estimators=10, + tree_method="hist", random_state=RANDOM_SEED) XGBOOST_PARAMS_LINEAR = dict(base_score=0.6, n_estimators=10, feature_selector="shuffle", booster="gblinear", random_state=RANDOM_SEED) @@ -170,6 +173,14 @@ def classification_binary_random(model): classification(xgboost.XGBClassifier(**XGBOOST_PARAMS)), classification_binary(xgboost.XGBClassifier(**XGBOOST_PARAMS)), + # XGBoost (tree method "hist") + regression(xgboost.XGBRegressor(**XGBOOST_HIST_PARAMS), + test_fraction=0.2), + classification(xgboost.XGBClassifier(**XGBOOST_HIST_PARAMS), + test_fraction=0.2), + classification_binary(xgboost.XGBClassifier(**XGBOOST_HIST_PARAMS), + test_fraction=0.2), + # XGBoost (LINEAR) regression(xgboost.XGBRegressor(**XGBOOST_PARAMS_LINEAR)), classification(xgboost.XGBClassifier(**XGBOOST_PARAMS_LINEAR)),