Fix #168. Enforce float32 type for split condition values for GBT models created using XGBoost #188
Conversation
@StrikerRUS In case you're curious and want to play with this issue (and for historical purposes), I'm attaching a serialized (pickled) XGBoost regression model trained using the "hist" method: round_error_xgboost.bin.gz.
Note the … This is where the generated code follows the "yes" path while XGBoost does the opposite.
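For anyone who wants to poke at it locally, loading the attachment should be as simple as the sketch below (assuming the artifact is a gzipped pickle, as the file name suggests):

```python
import gzip
import pickle

# Load the attached model; the file name comes from the comment above.
with gzip.open("round_error_xgboost.bin.gz", "rb") as f:
    model = pickle.load(f)

print(type(model))  # expected: an XGBoost regressor trained with tree_method="hist"
```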
Ah, brilliant investigation! BTW, doesn't this code contradict the following from one of treelite's issues?

m2cgen/m2cgen/assemblers/tree.py, lines 54 to 57 in 52c601b

Also refer to …
Just a few minor questions.
m2cgen/assemblers/boosting.py (Outdated)

```diff
@@ -134,7 +134,7 @@ def _assemble_tree(self, tree):
         if "leaf" in tree:
             return ast.NumVal(tree["leaf"])

-        threshold = ast.NumVal(tree["split_condition"])
+        threshold = ast.NumVal(np.float32(tree["split_condition"]))
```
Maybe a more general solution would be to add an optional `dtype` constructor argument? I mean:

```python
class NumVal(NumExpr):
    def __init__(self, value, dtype=np.float64):
```
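A minimal sketch of how that could behave, assuming the dtype is simply applied to the wrapped value (the stub base class is only here to make the snippet self-contained; the real `NumExpr` lives in m2cgen's `ast` module):

```python
import numpy as np


class NumExpr:
    """Stub standing in for m2cgen's NumExpr base class."""


class NumVal(NumExpr):
    def __init__(self, value, dtype=np.float64):
        # The dtype controls the precision of the stored constant.
        self.value = dtype(value)


print(NumVal(0.671000004).value)                    # 0.671000004 (float64 default)
print(NumVal(0.671000004, dtype=np.float32).value)  # 0.671 (the float32 rounding)
```

The XGBoost assembler would then pass `dtype=np.float32` only for split conditions, while everything else keeps the current float64 default.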
That's a good idea 👍
```diff
     CLASSIFICATION,
 )


-def regression_random(model):
+def regression_random(model, test_fraction=0.02):
```
Does the `test_fraction` increase for random datasets allow reproducing the original issue?
Yes; unfortunately, the default fraction produced way too few samples to reproduce the issue reliably.
Got it! We definitely need to refactor the testing routines to be more tunable, e.g. allow adjusting `test_fraction` according to the programming language. Refer to #114 (comment).
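Something along these lines, perhaps (names and values here are purely illustrative, not an existing m2cgen API):

```python
# Hypothetical per-language override for the e2e test sample size.
TEST_FRACTION_BY_LANG = {
    "python": 0.02,
    "java": 0.02,
    "c": 0.1,
}


def get_test_fraction(lang, default=0.02):
    return TEST_FRACTION_BY_LANG.get(lang, default)
```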
For reference, a more explicit proof that thresholds are supposed to be float: … Also, it seems that the threshold can be an integer: …
@StrikerRUS Thanks for all the additional context!
That's rather weird. I clearly remember that scikit-learn used float32 in its tree implementation as well. Perhaps a more recent fix?
Have no idea... https://github.com/scikit-learn/scikit-learn/blob/38030a00a7f72a3528bd17f2345f34d1344d6d45/sklearn/tree/_tree.pyx#L186
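For what it's worth, the dtype scikit-learn exposes for fitted thresholds is easy to check (a small sketch; behavior may differ between versions):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

X, y = np.random.rand(100, 3), np.random.rand(100)
est = DecisionTreeRegressor(max_depth=3).fit(X, y)

# Thresholds come back as float64 on recent versions, even though scikit-learn
# casts the input matrix X itself to float32 internally.
print(est.tree_.threshold.dtype)
```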
@StrikerRUS I only got a chance to address your comment now. Sorry about the delay.
149b0fa to 242165e
No problem at all! 🙂 Seems that the linter is unhappy with the imported … As for the scikit-learn issue (#188 (comment)), I believe it is better to have a separate PR.
Fixed the linter error and added a test. Thanks 👍
Thanks a lot for fixing this issue!
Everything looks OK to me.
As it turns out, the issue reported in #168 is not unique to the "hist" tree construction algorithm. It seems that with the "hist" method the likelihood of reproducing it is much higher due to its reliance on feature histograms. I was able to reproduce the same discrepancy with non-hist methods on a larger sample of test data.
The issue occurs due to a double-precision error and reproduces every time the feature value matches the split condition in one of the tree's nodes.
Example: feature value = `0.671`, split condition = `0.671000004`. When we hit this condition in the generated code, the outcome of `0.671 < 0.671000004` is "true" (the "yes" branch), while in XGBoost the same condition leads to the "no" branch.

After some investigation I noticed that XGBoost's `DMatrix` forces all values to be `float32` (https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/core.py#L565). At the same time, our assemblers rely on default 64-bit floats. Forcing the split condition to be `float32` seems to address the issue; at least I couldn't reproduce it so far.
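The mismatch is easy to demonstrate in isolation with the numbers from the example above (a small sketch of the two comparisons):

```python
import numpy as np

x = 0.671                # raw feature value, as the generated code receives it (float64)
threshold = 0.671000004  # split condition as printed in the XGBoost model dump

# What the generated code evaluated before the fix: a pure float64 comparison.
print(x < threshold)  # True  -> "yes" branch

# What XGBoost effectively evaluates: DMatrix has already cast the feature to
# float32, and the stored split condition is a float32 as well.
print(np.float32(x) < np.float32(threshold))  # False -> "no" branch

# 0.671 is not exactly representable in float32; its rounded value is exactly
# the number that gets printed as 0.671000004 in the dump.
print(float(np.float32(x)))  # 0.6710000038146973
```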