-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-3158][MLLIB]Avoid 1 extra aggregation for DecisionTree training #2708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
QA tests have started for PR 2708 at commit
|
|
QA tests have finished for PR 2708 at commit
|
|
Test FAILed. |
|
QA tests have started for PR 2708 at commit
|
|
QA tests have finished for PR 2708 at commit
|
|
Test PASSed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can also check stats.leftImpurity and rightImpurity. If stats.leftImpurity = 0, then we know the left child will be a leaf. Same for the right child.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once this is done, it might be good to add 1 more test to make sure it works. A slight modification of the test you already added should work.
|
@chouqin Thank you for the PR! It looks almost ready, except for the items I noted above. I will try some timing tests and post results here when done. |
|
QA tests have started for PR 2708 at commit
|
|
@jkbradley thanks for your comments, I have adjusted the code accordingly. I look forward to your timing test and hope that it will get some performance gain. |
|
QA tests have finished for PR 2708 at commit
|
|
Test PASSed. |
|
@chouqin Thanks for the updates! LGTM. I ran some timing tests and saw consistent speedups, but the speedups were not as big as I would have expected. (It was about 1.1X or 1.2X faster, but I would hope for close to 2X faster.) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
val nodes = new Array[Node](numNodes)
|
Test FAILed. |
|
test this please |
|
QA tests have started for PR 2708 at commit
|
|
QA tests have finished for PR 2708 at commit
|
|
Test PASSed. |
|
Merged into master. Thanks! |
Currently, the implementation does one unnecessary aggregation step. The aggregation step for level L (to choose splits) gives enough information to set the predictions of any leaf nodes at level L+1. We can use that info and skip the aggregation step for the last level of the tree (which only has leaf nodes).
Implementation Details
Each node now has a
impurityfield and thepredictis changed from typeDoubleto typePredict(this can be used to compute predict probability in the future) When compute best splits for each node, we also compute impurity and predict for the child nodes, which is used to constructed newly allocated child nodes. So at level L, we have set impurity and predict for nodes at level L +1.If level L+1 is the last level, then we can avoid aggregation. What's more, calculation of parent impurity in
Top nodes for each tree needs to be treated differently because we have to compute impurity and predict for them first. In
binsToBestSplit, if current node is top node(level == 0), we calculate impurity and predict first.after finding best split, top node's predict and impurity is set to the calculated value. Non-top nodes's impurity and predict are already calculated and don't need to be recalculated again. I have considered to add a initialization step to set top nodes' impurity and predict and then we can treat all nodes in the same way, but this will need a lot of duplication of code(all the code to do seq operation(BinSeqOp) needs to be duplicated), so I choose the current way.
CC @mengxr @manishamde @jkbradley, please help me review this, thanks.