
Conversation

@WeichenXu123 (Contributor) commented Nov 6, 2017

What changes were proposed in this pull request?

We do not need to generate all possible splits for unordered features before aggregation:

  • Change mixedBinSeqOp (which runs on the executor) so that, for each unordered feature, we collect the same statistics as for ordered features. So for unordered features, we only need O(numCategories) space per feature.

  • After the driver gets the aggregated result, generate all possible split combinations and compute the best split.

This reduces the decision tree aggregate size for each unordered feature from O(2^numCategories) to O(numCategories), where numCategories is the arity of the feature.

This also reduces the CPU cost on the executor side: the time complexity per unordered feature drops from O(numPoints * 2^numCategories) to O(numPoints).

This does not increase the time complexity of computing the best split for unordered features on the driver side.

How do we construct the statistics of the 2^(numCategories - 1) - 1 splits from the separate statistics of the numCategories bins?

I use a recursive function traverseUnorderedSplits to traverse all possible splits. Note that the statistics for each split are accumulated during the traversal, which keeps the total time complexity at O(2^n).
For example, suppose an unordered feature includes [a, b, c, d]. Each split is decoded from a binary representation, e.g. "1010" represents the split "[a, c] vs [b, d]" (a "1" allocates the value to the left child).
The combNumber of a split is the number whose binary representation matches it (note that the binary representation has its lowest bit on the left).
The combNumber is used for "pruning", as explained below.

Now we want to traverse all possible splits. There are 2^numCategories binary representations, but only 2^(numCategories - 1) - 1 distinct splits: we exclude the all-"0" and all-"1" representations, and we cut the remaining cases in half, because split "[a, c] vs [b, d]" and split "[b, d] vs [a, c]" are equivalent and we only need to traverse one of them. The way I filter out that other half is via the combNumber: we only traverse the cases satisfying 1 <= combNumber <= numSplits.
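As a quick sanity check of this counting (illustrative Python only, not code from the PR): among all 2^numCategories bit masks, the condition 1 <= combNumber <= numSplits excludes the all-"0" and all-"1" masks and keeps exactly one member of each complementary pair.

```python
n = 4                            # e.g. the categories [a, b, c, d]
num_splits = (1 << (n - 1)) - 1  # 2^(n-1) - 1 = 7 distinct splits
full = (1 << n) - 1              # the all-"1" mask

kept = [m for m in range(1 << n) if 1 <= m <= num_splits]

assert len(kept) == num_splits   # 7 splits survive out of 16 masks
for m in kept:
    assert m != 0 and m != full  # all-"0" / all-"1" excluded
    assert (m ^ full) not in kept  # the mirror-image split is dropped
```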

So I designed the traverseUnorderedSplits function to do this.
First, look at the dfs function inside traverseUnorderedSplits. It is a recursive function that computes the combNumber as it runs; when the current combNumber exceeds numSplits, the recursion stops.
The recursive dfs also accumulates the statistics. When generating a new split, for example "[a, c] vs [b, d]", we need the statistics for the left child, which equal stats[a] + stats[c]. However, I do not compute this sum after the split is generated, since that would raise the time complexity to O(n * 2^n); instead, I accumulate the stats during the recursion itself. You can check the stats parameter of dfs to see how it works.

So traverseUnorderedSplits is designed this way for two main reasons:

  1. To keep the total time complexity within O(2^n), the statistics accumulation must be handled carefully.
  2. To prune efficiently, via the condition 1 <= combNumber <= numSplits.
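A minimal Python sketch of the traversal described above (hypothetical names; the PR itself is Scala, and the real code threads an impurity statistics object rather than a number, with `visit` standing in for the split-evaluation callback):

```python
def traverse_unordered_splits(bin_stats, visit):
    """Visit every distinct left-child split of bin_stats exactly once.

    visit(comb_number, left_stats) receives the split's bit mask and the
    accumulated left-child statistic. Illustrative sketch, not PR code.
    """
    num_bins = len(bin_stats)
    num_splits = (1 << (num_bins - 1)) - 1  # 2^(n-1) - 1 distinct splits

    def dfs(bin_index, comb_number, left_stats):
        if bin_index == num_bins:
            if comb_number >= 1:        # skip the empty left child
                visit(comb_number, left_stats)
            return
        # Case 1: current bin goes to the right child; stats unchanged.
        dfs(bin_index + 1, comb_number, left_stats)
        # Case 2: current bin goes to the left child. Accumulate its stats
        # now, during the recursion, so each split costs O(1) instead of O(n).
        left_comb = comb_number + (1 << bin_index)
        if left_comb <= num_splits:     # pruning: drop mirror-image splits
            dfs(bin_index + 1, left_comb, left_stats + bin_stats[bin_index])

    dfs(0, 0, 0.0)
```

For example, with bin_stats = [1, 2, 4] the three visited masks are 1, 2, and 3, and each visited left-child statistic equals the sum of the bins selected by its mask; the empty mask and every mask above numSplits (the mirror images) are never visited.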

How was this patch tested?

UT added.

@SparkQA commented Nov 6, 2017

Test build #83481 has finished for PR 19666 at commit e79abfd.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • test("Multiclass classification with unordered categorical features: split calculations")

@smurching (Contributor) left a comment

I like this idea but I'm a bit confused by some parts of the code/wondering if there's anything we can simplify - I've left a few questions/comments.

I'm wondering if there's a simple way to iterate over the subsets for a given unordered categorical feature instead of generating the subsets recursively (e.g. by iterating over a gray code). This might simplify the logic in traverseUnorderedSplits; however, I'm fine with the recursive approach if it's well-documented.

val numBins = binAggregates.metadata.numBins(featureIndex)
val featureOffset = binAggregates.getFeatureOffset(featureIndexIdx)

val binStatsArray = Array.tabulate(numBins) { binIndex =>
Contributor:

Could you please add a comment explaining what this is? E.g.:
// Each element of binStatsArray stores pre-computed label statistics for a single bin of the current feature

(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
if (gainAndImpurityStats.gain > bestGain) {
bestGain = gainAndImpurityStats.gain
bestSet = set | new BitSet(numBins) // copy set
Contributor:

Why not use set.copy()?

Contributor Author:

The class does not support copy.

var bestLeftChildStats: ImpurityCalculator = null
var bestRightChildStats: ImpurityCalculator = null

traverseUnorderedSplits[ImpurityCalculator](numBins, null,
Contributor:

Could you please add a comment explaining what this does? E.g.:
// Computes the best split for the current feature, storing the result across the vars above

categories
}

private[tree] def traverseUnorderedSplits[T](
Contributor:

Could you please add a docstring for this method, since it's a bit complicated?

Contributor:

Also, does traverseUnorderedSplits need to take a type parameter / two different closures as method arguments? AFAICT the use of a type parameter/closures here allows us to unit test this functionality on a simple example, but I wonder if we could simplify this somehow.

} else {
subSet.set(binIndex)
val leftChildCombNumber = combNumber + (1 << binIndex)
// pruning: only need combNumber satisfy: 1 <= combNumber <= numSplits
Contributor:

If I understand correctly, the check if (leftChildCombNumber <= numSplits) helps us ensure that we consider each split only once, right?

Contributor Author:

Yes. For example, "00101" and "11010" are equivalent splits; we should traverse only one of them.
So here I use the condition 1 <= combNumber <= numSplits to do the pruning. It simply filters out the other half of the splits.

@WeichenXu123 (Contributor Author) commented Nov 7, 2017

@smurching I guess iterating over a gray code would have a higher time complexity, O(n * 2^n) (not entirely sure; maybe there are more efficient algorithms?), while the recursive traversal in my PR only needs O(2^n).
Also, recursive traversal has an advantage for "pruning": with the condition "1 <= combNumber <= numSplits" mentioned above, it can prune an entire subtree and save about half the time. Can iterating over a gray code also do this while keeping the complexity O(2^n)?
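For what it's worth, a reflected Gray code might keep both properties: consecutive codes differ in exactly one bit, so the left-child statistic can be updated in O(1) per step (O(2^n) total), and bounding the iteration to the first half of the sequence keeps the highest bit 0, which drops the mirror-image splits much like the combNumber condition. A hedged Python sketch (hypothetical names, not code from the PR):

```python
def gray_code_splits(bin_stats):
    """Return (mask, left_stats) for each distinct split, via a Gray code.

    gray(i) = i ^ (i >> 1) flips exactly one bit between consecutive i, so
    one bin's statistic is added or removed per step. Iterating i in
    [1, 2^(n-1)) keeps the highest bit 0, so the last bin always stays in
    the right child and mirror-image splits never appear.
    """
    n = len(bin_stats)
    left = 0.0          # running left-child statistic
    prev = 0            # gray(0): the empty left set
    results = []
    for i in range(1, 1 << (n - 1)):
        gray = i ^ (i >> 1)
        changed = gray ^ prev            # exactly one bit differs
        bit = changed.bit_length() - 1
        if gray & changed:               # bit turned on: bin joins left
            left += bin_stats[bit]
        else:                            # bit turned off: bin leaves left
            left -= bin_stats[bit]
        prev = gray
        results.append((gray, left))
    return results
```

Whether this beats the recursive version is a judgment call: the subtree pruning the recursion gets for free is replaced here by simply bounding the loop at 2^(n-1).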

}

test("traverseUnorderedSplits") {

Contributor:

Since traverseUnorderedSplits is a private method, I wonder whether we can check the unordered splits on DecisionTree directly? For example, create a tiny dataset and generate a shallow tree (depth = 1?). I know such a test case is difficult (maybe impossible) to design; however, it would focus on behavior instead of implementation.

Contributor Author:

So how can we test all possible splits to make sure the generated splits are all correct? Once a tree is generated, only the best split remains.

@facaiy (Contributor) commented Nov 7, 2017

I believe that unordered features will benefit a lot from this idea; however, I have two questions:

  1. I'm a little confused by 964L in traverseUnorderedSplits. Is it a backtracking algorithm?
dfs(binIndex + 1, combNumber, stats)
  2. traverseUnorderedSplits succeeds in decoupling the abstraction from its implementation. Cool! However, I wonder whether we could write a simpler function, say, by removing seqOp, finalizer, and T and collecting all the logic together in one place?

Anyway, thanks for your good work, @WeichenXu123.

@WeichenXu123 (Contributor Author)

@facaiy Thanks for your review! I added more explanation of the design rationale for traverseUnorderedSplits. But if you have a better solution, don't hesitate to tell me!

@WeichenXu123 (Contributor Author)

Also cc @smurching Thanks!

@facaiy (Contributor) commented Nov 9, 2017

Hi, I wrote a demo in Python. I'll be happy if it's useful.

For N bins, say [x_1, x_2, ..., x_N]: since every split either contains x_1 or not, we can choose the half of the splits that do not contain x_1 as the left splits.
If I understand correctly, the left splits are then exactly all combinations of the remaining bins, [x_2, x_3, ..., x_N]. This can be solved with a backtracking algorithm.

Please correct me if I'm wrong. Thanks very much.

#!/usr/bin/env python

def gen_splits(bins):
    if len(bins) == 1:
        return bins
    results = []
    partial_res = []
    # Start at index 1 so that bins[0] never joins a left split,
    # which filters out the mirror-image half of the splits.
    gen_splits_iter(1, bins, partial_res, results)
    return results


def gen_splits_iter(dep, bins, partial_res, results):
    if partial_res:
        left_splits = partial_res[:]
        right_splits = [x for x in bins if x not in left_splits]
        results.append("left: {:20}, right: {}".format(str(left_splits), right_splits))

    # Backtracking: extend the current left split with each later bin in turn.
    for m in range(dep, len(bins)):
        partial_res.append(bins[m])
        gen_splits_iter(m+1, bins, partial_res, results)
        partial_res.pop()


if __name__ == "__main__":
    print("first example:")
    bins = ["a", "b", "c"]
    print("bins: {}\n-----".format(bins))
    splits = gen_splits(bins)
    for s in splits:
        print(s)

    print("\n\n=============")
    print("second example:")
    bins = ["a", "b", "c", "d", "e"]
    print("bins: {}\n-----".format(bins))
    splits = gen_splits(bins)
    for s in splits:
        print(s)

logs:

~/Downloads ❯❯❯ python test.py
first example:
bins: ['a', 'b', 'c']
-----
left: ['b']               , right: ['a', 'c']
left: ['b', 'c']          , right: ['a']
left: ['c']               , right: ['a', 'b']


=============
second example:
bins: ['a', 'b', 'c', 'd', 'e']
-----
left: ['b']               , right: ['a', 'c', 'd', 'e']
left: ['b', 'c']          , right: ['a', 'd', 'e']
left: ['b', 'c', 'd']     , right: ['a', 'e']
left: ['b', 'c', 'd', 'e'], right: ['a']
left: ['b', 'c', 'e']     , right: ['a', 'd']
left: ['b', 'd']          , right: ['a', 'c', 'e']
left: ['b', 'd', 'e']     , right: ['a', 'c']
left: ['b', 'e']          , right: ['a', 'c', 'd']
left: ['c']               , right: ['a', 'b', 'd', 'e']
left: ['c', 'd']          , right: ['a', 'b', 'e']
left: ['c', 'd', 'e']     , right: ['a', 'b']
left: ['c', 'e']          , right: ['a', 'b', 'd']
left: ['d']               , right: ['a', 'b', 'c', 'e']
left: ['d', 'e']          , right: ['a', 'b', 'c']
left: ['e']               , right: ['a', 'b', 'c', 'd']

@facaiy (Contributor) commented Nov 9, 2017

In fact, I'm not sure whether the idea is right, so don't hesitate to correct me. I assume the algorithm requires O(2^{N-1}) complexity.

@WeichenXu123 (Contributor Author) commented Nov 9, 2017

@facaiy Your idea also looks reasonable. So we can use the condition "exclude the first bin" to do the pruning (filtering out the other half of the symmetric splits). That condition looks simpler than 1 <= combNumber <= numSplits. Good idea!
And your code uses a different traversal order; my current PR is also backtracking, just with a different order. I think both of them work, and both have complexity O(2^n).

@facaiy (Contributor) commented Nov 9, 2017

Thank you, @WeichenXu123. You could also use the condition "include the first bin" to filter the left splits. Perhaps that is better.

@jkbradley (Member)

Btw, this is going to conflict with #19433 a lot. @WeichenXu123 and @smurching, have you planned for merging one before the other?

@smurching (Contributor) commented Nov 14, 2017

Discussed with @jkbradley, I'll split up #19433 so that the parts of it that'd conflict with this PR (refactoring RandomForest.scala into utility classes) can be merged first.

@WeichenXu123 (Contributor Author)

OK. I will wait for @smurching's split-out parts of #19433 to get merged first, and then I will update this PR.

// Hack: If a categorical feature has only 1 category, we treat it as continuous.
// TODO(SPARK-9957): Handle this properly by filtering out those features.
if (numCategories > 1) {
// Decide if some categorical features should be treated as unordered features,
Contributor:

Change , to .

@WeichenXu123 WeichenXu123 deleted the improve_decision_tree branch April 24, 2019 21:18