Skip to content

Commit

Permalink
Fix #168. Enforce float32 type for split condition values for GBT models created using XGBoost
Browse files Browse the repository at this point in the history
(commit title continued above: "…els created using XGBoost")
  • Loading branch information
izeigerman committed Mar 30, 2020
1 parent 52c601b commit e13bea3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
2 changes: 1 addition & 1 deletion m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _assemble_tree(self, tree):
if "leaf" in tree:
return ast.NumVal(tree["leaf"])

threshold = ast.NumVal(tree["split_condition"])
threshold = ast.NumVal(np.float32(tree["split_condition"]))
split = tree["split"]
feature_idx = self._feature_name_to_idx.get(split, split)
feature_ref = ast.FeatureRef(feature_idx)
Expand Down
35 changes: 23 additions & 12 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,50 +35,51 @@


# Set of helper functions to make parametrization less verbose.
def regression(model):
def regression(model, test_fraction=0.02):
return (
model,
utils.get_regression_model_trainer(),
utils.get_regression_model_trainer(test_fraction),
REGRESSION,
)


def classification(model):
def classification(model, test_fraction=0.02):
return (
model,
utils.get_classification_model_trainer(),
utils.get_classification_model_trainer(test_fraction),
CLASSIFICATION,
)


def classification_binary(model):
def classification_binary(model, test_fraction=0.02):
return (
model,
utils.get_binary_classification_model_trainer(),
utils.get_binary_classification_model_trainer(test_fraction),
CLASSIFICATION,
)


def regression_random(model):
def regression_random(model, test_fraction=0.02):
return (
model,
utils.get_regression_random_data_model_trainer(0.01),
utils.get_regression_random_data_model_trainer(test_fraction),
REGRESSION,
)


def classification_random(model):
def classification_random(model, test_fraction=0.02):
return (
model,
utils.get_classification_random_data_model_trainer(0.01),
utils.get_classification_random_data_model_trainer(test_fraction),
CLASSIFICATION,
)


def classification_binary_random(model):
def classification_binary_random(model, test_fraction=0.02):
return (
model,
utils.get_classification_binary_random_data_model_trainer(0.01),
utils.get_classification_binary_random_data_model_trainer(
test_fraction),
CLASSIFICATION,
)

Expand All @@ -92,6 +93,8 @@ def classification_binary_random(model):
FOREST_PARAMS = dict(n_estimators=10, random_state=RANDOM_SEED)
XGBOOST_PARAMS = dict(base_score=0.6, n_estimators=10,
random_state=RANDOM_SEED)
XGBOOST_HIST_PARAMS = dict(base_score=0.6, n_estimators=10,
tree_method="hist", random_state=RANDOM_SEED)
XGBOOST_PARAMS_LINEAR = dict(base_score=0.6, n_estimators=10,
feature_selector="shuffle", booster="gblinear",
random_state=RANDOM_SEED)
Expand Down Expand Up @@ -170,6 +173,14 @@ def classification_binary_random(model):
classification(xgboost.XGBClassifier(**XGBOOST_PARAMS)),
classification_binary(xgboost.XGBClassifier(**XGBOOST_PARAMS)),
# XGBoost (tree method "hist")
regression(xgboost.XGBRegressor(**XGBOOST_HIST_PARAMS),
test_fraction=0.2),
classification(xgboost.XGBClassifier(**XGBOOST_HIST_PARAMS),
test_fraction=0.2),
classification_binary(xgboost.XGBClassifier(**XGBOOST_HIST_PARAMS),
test_fraction=0.2),
# XGBoost (LINEAR)
regression(xgboost.XGBRegressor(**XGBOOST_PARAMS_LINEAR)),
classification(xgboost.XGBClassifier(**XGBOOST_PARAMS_LINEAR)),
Expand Down

0 comments on commit e13bea3

Please sign in to comment.