[SPARK-22451][ML] Reduce decision tree aggregate size for unordered features from O(2^numCategories) to O(numCategories) #19666
Conversation
Test build #83481 has finished for PR 19666 at commit
smurching left a comment:
I like this idea but I'm a bit confused by some parts of the code/wondering if there's anything we can simplify - I've left a few questions/comments.
I'm wondering if there's a simple way to iterate over the subsets for a given unordered categorical feature instead of generating the subsets recursively (e.g. by iterating over a gray code). This might simplify the logic in traverseUnorderedSplits; however, I'm fine with the recursive approach if it's well-documented.
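For reference, iterating over subsets in Gray-code order is straightforward; here is a minimal Python sketch of the idea smurching suggests (an illustration only, not the Spark code — the function name `gray_code_subsets` is made up). Because successive masks differ in exactly one bit, a running statistics accumulator could be updated by adding or removing a single bin's stats per step, so the total work stays proportional to the number of subsets:

```python
def gray_code_subsets(num_categories):
    """Yield (subset_mask, flipped_bit) pairs in Gray-code order.

    Successive masks differ in exactly one bit; flipped_bit is the
    index of the bit that changed relative to the previous mask.
    """
    prev = 0
    for i in range(1, 1 << num_categories):
        mask = i ^ (i >> 1)                        # binary-reflected Gray code
        flipped = (mask ^ prev).bit_length() - 1   # index of the changed bit
        yield mask, flipped
        prev = mask

# Example: 3 categories -> masks visited in Gray order
masks = [m for m, _ in gray_code_subsets(3)]       # [1, 3, 2, 6, 7, 5, 4]
```

Note this would still need the same symmetry filtering as the recursive approach, since a mask and its complement describe the same split.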
```scala
val numBins = binAggregates.metadata.numBins(featureIndex)
val featureOffset = binAggregates.getFeatureOffset(featureIndexIdx)

val binStatsArray = Array.tabulate(numBins) { binIndex =>
```
Could you please add a comment explaining what this is? E.g.:
// Each element of binStatsArray stores pre-computed label statistics for a single bin of the current feature
```scala
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)

if (gainAndImpurityStats.gain > bestGain) {
  bestGain = gainAndImpurityStats.gain
  bestSet = set | new BitSet(numBins) // copy set
```
Why not use set.copy()?
The class does not support copy().
```scala
var bestLeftChildStats: ImpurityCalculator = null
var bestRightChildStats: ImpurityCalculator = null

traverseUnorderedSplits[ImpurityCalculator](numBins, null,
```
Could you please add a comment explaining what this does? E.g.:
// Computes the best split for the current feature, storing the result across the vars above
```scala
  categories
}

private[tree] def traverseUnorderedSplits[T](
```
Could you please add a docstring for this method, since it's a bit complicated?
Also, does traverseUnorderedSplits need to take a type parameter / two different closures as method arguments? AFAICT the use of a type parameter/closures here allow us to unit test this functionality on a simple example, but I wonder if we could simplify this somehow.
```scala
} else {
  subSet.set(binIndex)
  val leftChildCombNumber = combNumber + (1 << binIndex)
  // pruning: only need combNumber satisfy: 1 <= combNumber <= numSplits
```
If I understand correctly, the check if (leftChildCombNumber <= numSplits) helps us ensure that we consider each split only once, right?
Yes. For example, "00101" and "11010" are equivalent splits, so we should traverse only one of them.
Here I use the condition 1 <= combNumber <= numSplits to do the pruning. It simply filters out the other half of the splits.
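The enumeration with this pruning can be sketched in a few lines of Python (a standalone illustration under the PR's encoding, not the Spark code; `unordered_splits` is a made-up name). Each left-child set is encoded as a bitmask (the combNumber), and any branch whose mask would exceed numSplits is cut, since setting further (higher) bits can only increase the mask:

```python
def unordered_splits(num_bins):
    """Enumerate each unordered split of num_bins bins exactly once.

    A split and its mirror image are equivalent, so only masks with
    1 <= combNumber <= numSplits are visited, where
    numSplits = 2**(num_bins - 1) - 1.
    """
    num_splits = (1 << (num_bins - 1)) - 1
    results = []

    def dfs(bin_index, comb_number):
        if bin_index == num_bins:
            if 1 <= comb_number <= num_splits:   # skip the empty left child
                results.append(comb_number)
            return
        dfs(bin_index + 1, comb_number)          # current bin goes right
        left = comb_number + (1 << bin_index)
        if left <= num_splits:                   # pruning: mirror splits cut
            dfs(bin_index + 1, left)             # current bin goes left

    dfs(0, 0)
    return results

# 4 bins -> 2**(4-1) - 1 = 7 distinct splits (masks 1..7)
```

Every mask from 1 to numSplits appears exactly once, which matches the 2^(numCategories - 1) - 1 count in the PR description.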
@smurching I guess iterating over a gray code would have a higher time complexity, O(n * 2^n) (not very sure, maybe there are more efficient algorithms?), while the recursive traversal in my PR only needs O(2^n).
```scala
}

test("traverseUnorderedSplits") {
```
Since traverseUnorderedSplits is a private method, I wonder whether we can check the unordered splits on DecisionTree directly? For example, create a tiny dataset and generate a shallow tree (depth = 1?). I know such a test case is difficult (maybe impossible) to design, but it would focus on behavior instead of implementation.
So how can we test all possible splits to make sure the generated splits are all correct? If a tree is generated, only the best split remains.
I believe that unordered features will benefit a lot from the idea; however, I have two questions:
`dfs(binIndex + 1, combNumber, stats)`
Anyway, thanks for your good work, @WeichenXu123.
@facaiy Thanks for your review! I put more explanation on the design purpose of `traverseUnorderedSplits`.
Also cc @smurching Thanks!
Hi, I wrote a demo in Python. I'll be happy if it could be useful. For N bins, say [...]. Please correct me if I'm wrong. Thanks very much.

```python
#!/usr/bin/env python

def gen_splits(bins):
    if len(bins) == 1:
        return bins
    results = []
    partial_res = []
    gen_splits_iter(1, bins, partial_res, results)
    return results

def gen_splits_iter(dep, bins, partial_res, results):
    if partial_res:
        left_splits = partial_res[:]
        right_splits = [x for x in bins if x not in left_splits]
        results.append("left: {:20}, right: {}".format(str(left_splits), right_splits))
    for m in range(dep, len(bins)):
        partial_res.append(bins[m])
        gen_splits_iter(m + 1, bins, partial_res, results)
        partial_res.pop()

if __name__ == "__main__":
    print("first example:")
    bins = ["a", "b", "c"]
    print("bins: {}\n-----".format(bins))
    splits = gen_splits(bins)
    for s in splits:
        print(s)
    print("\n\n=============")
    print("second example:")
    bins = ["a", "b", "c", "d", "e"]
    print("bins: {}\n-----".format(bins))
    splits = gen_splits(bins)
    for s in splits:
        print(s)
```

Logs:

```
~/Downloads ❯❯❯ python test.py
first example:
bins: ['a', 'b', 'c']
-----
left: ['b']             , right: ['a', 'c']
left: ['b', 'c']        , right: ['a']
left: ['c']             , right: ['a', 'b']


=============
second example:
bins: ['a', 'b', 'c', 'd', 'e']
-----
left: ['b']             , right: ['a', 'c', 'd', 'e']
left: ['b', 'c']        , right: ['a', 'd', 'e']
left: ['b', 'c', 'd']   , right: ['a', 'e']
left: ['b', 'c', 'd', 'e'], right: ['a']
left: ['b', 'c', 'e']   , right: ['a', 'd']
left: ['b', 'd']        , right: ['a', 'c', 'e']
left: ['b', 'd', 'e']   , right: ['a', 'c']
left: ['b', 'e']        , right: ['a', 'c', 'd']
left: ['c']             , right: ['a', 'b', 'd', 'e']
left: ['c', 'd']        , right: ['a', 'b', 'e']
left: ['c', 'd', 'e']   , right: ['a', 'b']
left: ['c', 'e']        , right: ['a', 'b', 'd']
left: ['d']             , right: ['a', 'b', 'c', 'e']
left: ['d', 'e']        , right: ['a', 'b', 'c']
left: ['e']             , right: ['a', 'b', 'c', 'd']
```
In fact, I'm not sure whether the idea is right, so don't hesitate to correct me. I assume the algorithm requires O(2^{N-1}) complexity.
@facaiy Your idea also looks reasonable. So we can use the condition "exclude the first bin" to do the pruning (filter out the other half of the symmetric splits). This condition looks simpler than `1 <= combNumber <= numSplits`.
Thank you, @WeichenXu123. You can also use the condition "include the first bin" to filter the left splits. Perhaps it is better.
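The "include the first bin" convention can be sketched in Python with `itertools.combinations` (an illustration only; `splits_including_first_bin` is a made-up name). Fixing the first bin in the left child discards the mirror image of every split, so exactly 2^(N-1) - 1 splits remain for N bins:

```python
from itertools import combinations

def splits_including_first_bin(bins):
    """Enumerate unordered splits; the left child always contains bins[0].

    Mirror-image splits are thereby skipped, giving 2**(n-1) - 1 splits.
    """
    rest = bins[1:]
    splits = []
    # left = {bins[0]} plus a proper subset of the rest, so right is non-empty
    for size in range(0, len(rest)):
        for extra in combinations(rest, size):
            left = [bins[0]] + list(extra)
            right = [b for b in rest if b not in extra]
            splits.append((left, right))
    return splits

# 5 bins -> 2**4 - 1 = 15 splits, each with "a" in the left child
```

This produces the same 15 splits as the demo above for 5 bins, just with the roles of left and right swapped.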
Btw, this is going to conflict with #19433 a lot. @WeichenXu123 and @smurching, have you planned for merging one before the other?
Discussed with @jkbradley, I'll split up #19433 so that the parts of it that'd conflict with this PR (refactoring RandomForest.scala into utility classes) can be merged first.
OK. I will wait for @smurching to get the split-out parts of #19433 merged first, and then I will update this PR.
```scala
// Hack: If a categorical feature has only 1 category, we treat it as continuous.
// TODO(SPARK-9957): Handle this properly by filtering out those features.
if (numCategories > 1) {
  // Decide if some categorical features should be treated as unordered features,
```
Change `,` to `.`
What changes were proposed in this pull request?

We do not need to generate all possible splits for unordered features before aggregation.

Change `mixedBinSeqOp` (which runs on the executors) so that, for each unordered feature, we collect the same statistics as for ordered features. So for unordered features, we only need O(numCategories) space for each feature's statistics. After the driver side gets the aggregate result, it generates all possible split combinations and computes the best split.

This will reduce the decision tree aggregate size for each unordered feature from O(2^numCategories) to O(numCategories), where numCategories is the arity of this unordered feature. This also reduces the CPU cost on the executor side: the time complexity for each unordered feature drops from O(numPoints * 2^numCategories) to O(numPoints). This won't increase the time complexity of the unordered-feature best-split computation on the driver side.

How do we construct the statistics of the 2^(numCategories - 1) - 1 splits from the separate statistics of the numCategories bins?

I use a recursive function `traverseUnorderedSplits` to traverse all possible splits, and note that, during the traversal, the statistics for each split are accumulated, so the total time complexity stays O(2^n).

An example: suppose an unordered feature has categories [a, b, c, d]. Each split is encoded by a binary representation, e.g. "1010" represents the split "[a, c] vs [b, d]" (a "1" means the category is allocated to the left child). The `combNumber` of the split is the number whose binary representation matches it (note that the binary representation has its lower bits starting from the left). The `combNumber` will be used for pruning (explained below).

Now we want to traverse all possible splits. There are 2^numCategories possible binary representations, but only 2^(numCategories - 1) - 1 splits: we need to exclude the all-"0" and all-"1" representations, and of the remaining ones we need to cut half, because split "[a, c] vs [b, d]" and split "[b, d] vs [a, c]" are equivalent, so we only need to traverse one of them. The way I filter out the other half is via the `combNumber`: we only traverse the cases whose `combNumber` satisfies 1 <= combNumber <= numSplits.

The `traverseUnorderedSplits` function does this. First look at the `dfs` function inside `traverseUnorderedSplits`: it is recursive, and while running it computes the `combNumber` at the same time. When the current `combNumber` exceeds numSplits, it stops recursing.

The recursive `dfs` also accumulates the statistics. When generating a new split, for example "[a, c] vs [b, d]", we need to compute the statistics for the left child, which equal stats[a] + stats[c]. But I do not compute this accumulation after the split is generated, because that would make the algorithm's time complexity O(n * 2^n); instead I accumulate the stats during the recursion. You can check the `stats` param of `dfs` to see how it works.

So the two main reasons `traverseUnorderedSplits` is designed this way are: the pruning via 1 <= combNumber <= numSplits, and the incremental accumulation of statistics during the recursion.

How was this patch tested?

UT added.
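The whole scheme described above (per-bin aggregation, symmetry pruning, and incremental stats accumulation inside the recursion) can be sketched end-to-end in Python. This is a minimal illustration, not Spark's actual code: `best_unordered_split`, `bin_stats`, and `impurity` are made-up names, and the per-bin statistics are plain label-count vectors. The key point is that the left-child statistics are threaded through the recursion, so forming each split's statistics costs O(1) extra work per recursion step and the whole search stays O(2^n) instead of O(n * 2^n):

```python
def best_unordered_split(bin_stats, impurity):
    """Find the best split of an unordered feature from per-bin stats.

    bin_stats: list of per-bin label-statistic vectors (one per category).
    impurity:  function (left_stats, right_stats) -> gain; larger is better.
    Returns (combNumber of the best split, its gain).
    """
    num_bins = len(bin_stats)
    num_splits = (1 << (num_bins - 1)) - 1
    total = [sum(col) for col in zip(*bin_stats)]
    best = {"gain": float("-inf"), "mask": 0}

    def dfs(bin_index, comb_number, left_stats):
        if bin_index == num_bins:
            if 1 <= comb_number <= num_splits:
                # right-child stats are derived from the totals in O(k)
                right_stats = [t - l for t, l in zip(total, left_stats)]
                gain = impurity(left_stats, right_stats)
                if gain > best["gain"]:
                    best["gain"] = gain
                    best["mask"] = comb_number
            return
        dfs(bin_index + 1, comb_number, left_stats)      # bin -> right child
        new_comb = comb_number + (1 << bin_index)
        if new_comb <= num_splits:                       # symmetry pruning
            # accumulate this bin's stats into the left child incrementally
            new_stats = [l + s for l, s in zip(left_stats, bin_stats[bin_index])]
            dfs(bin_index + 1, new_comb, new_stats)      # bin -> left child

    dfs(0, 0, [0] * len(bin_stats[0]))
    return best["mask"], best["gain"]
```

For example, with three categories whose label counts are [3, 0], [0, 3], [3, 0] and a negated weighted-Gini gain, the best split isolates the middle category (combNumber 2, i.e. binary "010" read low-bit-first).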