-
Notifications
You must be signed in to change notification settings - Fork 3
Implements chooseUnorderedCategoricalSplit #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
c7f7029
d599440
0f45a4d
555d0f6
8c71d6c
2829936
e432fee
7b431df
b0b0882
215f31f
ee8037a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -102,7 +102,7 @@ private[ml] object AltDT extends Logging { | |
| parentUID: Option[String] = None): DecisionTreeModel = { | ||
| // TODO: Check validity of params | ||
| val rootNode = trainImpl(input, strategy) | ||
| RandomForest.finalizeTree(rootNode, strategy.algo, strategy.numClasses, parentUID) | ||
| impl.RandomForest.finalizeTree(rootNode, strategy.algo, strategy.numClasses, parentUID) | ||
| } | ||
|
|
||
| private[impl] def trainImpl(input: RDD[LabeledPoint], strategy: Strategy): Node = { | ||
|
|
@@ -529,12 +529,115 @@ private[ml] object AltDT extends Logging { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Find the best split for an unordered categorical feature at a single node. | ||
| * | ||
| * Algorithm: | ||
| * - Considers all possible subsets (exponentially many) | ||
| * | ||
| * @param featureIndex Index of feature being split. | ||
| * @param values Feature values at this node. Sorted in increasing order. | ||
| * @param labels Labels corresponding to values, in the same order. | ||
| * @return (best split, statistics for split) If the best split actually puts all instances | ||
| * in one leaf node, then it will be set to None. The impurity stats maybe still be | ||
| * useful, so they are returned. | ||
| */ | ||
| private[impl] def chooseUnorderedCategoricalSplit( | ||
| featureIndex: Int, | ||
| values: Seq[Double], | ||
| labels: Seq[Double], | ||
| metadata: AltDTMetadata, | ||
| featureArity: Int): (Option[Split], ImpurityStats) = ??? | ||
| featureArity: Int): (Option[Split], ImpurityStats) = { | ||
|
|
||
| // Label stats for each category | ||
| val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)( | ||
| _ => metadata.createImpurityAggregator()) | ||
| values.zip(labels).foreach { case (cat, label) => | ||
| // NOTE: we assume the values for categorical features are Ints in [0,featureArity) | ||
| aggStats(cat.toInt).update(label) | ||
| } | ||
|
|
||
| // Aggregated statistics for left part of split and entire split. | ||
| val leftImpurityAgg = metadata.createImpurityAggregator() | ||
| val fullImpurityAgg = metadata.createImpurityAggregator() | ||
| aggStats.foreach(fullImpurityAgg.add) | ||
| val fullImpurity = fullImpurityAgg.getCalculator.calculate() | ||
|
|
||
| if (featureArity == 1) { | ||
| // All instances go right | ||
| val impurityStats = new ImpurityStats(0.0, fullImpurityAgg.getCalculator.calculate(), | ||
| fullImpurityAgg.getCalculator, leftImpurityAgg.getCalculator, | ||
| fullImpurityAgg.getCalculator) | ||
| (None, impurityStats) | ||
| } else { | ||
| // TODO: We currently add and remove the stats for all categories for each split. | ||
| // A better way to do it would be to consider splits in an order such that each iteration | ||
| // only requires addition/removal of a single category and a single add/subtract to | ||
| // leftCount and rightCount. | ||
| // TODO: Use more efficient encoding such as gray codes | ||
| val splits: Array[CategoricalSplit] = findSplits(featureIndex, featureArity, metadata) | ||
| var bestSplit: Option[CategoricalSplit] = None | ||
| val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() | ||
| var bestGain: Double = -1.0 | ||
| val fullCount: Double = values.size | ||
| for (split <- splits) { | ||
| // Update left, right impurity stats | ||
| split.leftCategories.foreach(c => leftImpurityAgg.add(aggStats(c.toInt))) | ||
| val rightImpurityAgg = fullImpurityAgg.subtract(leftImpurityAgg) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This modifies fullImpurityAgg. It should create a copy first (or share 1 copy for the whole loop and overwrite on each iteration).
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
| val leftCount = leftImpurityAgg.getCount | ||
| val rightCount = rightImpurityAgg.getCount | ||
| // Compute impurity | ||
| val leftWeight = leftCount / fullCount | ||
| val rightWeight = rightCount / fullCount | ||
| val leftImpurity = leftImpurityAgg.getCalculator.calculate() | ||
| val rightImpurity = rightImpurityAgg.getCalculator.calculate() | ||
| val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity | ||
| if (gain > bestGain && gain > metadata.minInfoGain) { | ||
| bestSplit = Some(split) | ||
| leftImpurityAgg.stats.copyToArray(bestLeftImpurityAgg.stats) | ||
| bestGain = gain | ||
| } | ||
| // Reset left and full impurity stats | ||
| rightImpurityAgg.add(leftImpurityAgg) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or did you mean to reset fullImpurityAgg here? Regardless, I prefer keeping a stable copy instead of doing more complicated editing and reseting.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
| leftImpurityAgg.clear() | ||
| } | ||
|
|
||
| val bestFeatureSplit = bestSplit match { | ||
| case Some(split) => Some( | ||
| new CategoricalSplit(featureIndex, split.leftCategories, featureArity)) | ||
| case None => None | ||
|
|
||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove newline
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
| } | ||
| val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg) | ||
| val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, | ||
| fullImpurityAgg.getCalculator, bestLeftImpurityAgg.getCalculator, | ||
| bestRightImpurityAgg.getCalculator) | ||
| (bestFeatureSplit, bestImpurityStats) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Returns all possible subsets of features for categorical splits. | ||
| */ | ||
| private def findSplits( | ||
| featureIndex: Int, | ||
| featureArity: Int, | ||
| metadata: AltDTMetadata): Array[CategoricalSplit] = { | ||
| // Unordered features | ||
| // 2^(featureArity - 1) - 1 combinations | ||
| val numSplits = (1 << (featureArity - 1)) - 1 | ||
| val splits = new Array[CategoricalSplit](numSplits) | ||
|
|
||
| var splitIndex = 0 | ||
| while (splitIndex < numSplits) { | ||
| val categories: List[Double] = | ||
| RandomForest.extractMultiClassCategories(splitIndex + 1, featureArity) | ||
| splits(splitIndex) = | ||
| new CategoricalSplit(featureIndex, categories.toArray, featureArity) | ||
| splitIndex += 1 | ||
| } | ||
| splits | ||
| } | ||
|
|
||
| /** | ||
| * Choose splitting rule: feature value <= threshold | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -231,7 +231,26 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| assert(stats.valid) | ||
| } | ||
|
|
||
| // test("chooseUnorderedCategoricalSplit: basic case") { } | ||
| test("chooseUnorderedCategoricalSplit: basic case") { | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please add 1 more test, analogous to the others with "return bad split if we should not split" in the test name? That will cover the case where we expect to get valid stats back, even when we do not find a split.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
| val featureIndex = 0 | ||
| val featureArity = 4 | ||
| val values = Seq(3.0, 1.0, 0.0, 2.0, 2.0) | ||
| val labels = Seq(0.0, 0.0, 1.0, 1.0, 1.0) | ||
| val impurity = Entropy | ||
| val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity) | ||
| val (split, stats) = AltDT.chooseUnorderedCategoricalSplit( | ||
| featureIndex, values, labels, metadata, featureArity) | ||
| split match { | ||
| case Some(s: CategoricalSplit) => | ||
| assert(s.featureIndex === featureIndex) | ||
| assert(s.leftCategories.toSet === Set(0.0, 2.0)) | ||
| assert(s.rightCategories.toSet === Set(1.0, 3.0)) | ||
| // TODO: test correctness of stats | ||
| case _ => | ||
| throw new AssertionError( | ||
| s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") | ||
| } | ||
| } | ||
|
|
||
| // test("chooseUnorderedCategoricalSplit: return bad split if we should not split") { } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note "gray codes" here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK