Implements chooseUnorderedCategoricalSplit #12
Changes from 8 commits
c7f7029
d599440
0f45a4d
555d0f6
8c71d6c
2829936
e432fee
7b431df
b0b0882
215f31f
ee8037a
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -529,12 +529,143 @@ private[ml] object AltDT extends Logging { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Find the best split for an unordered categorical feature at a single node. | ||
| * | ||
| * Algorithm: | ||
| * - Considers all possible subsets (possibly exponentially many) | ||
| * | ||
| * @param featureIndex Index of feature being split. | ||
| * @param values Feature values at this node. Sorted in increasing order. | ||
| * @param labels Labels corresponding to values, in the same order. | ||
| * @return (best split, statistics for split) If the best split actually puts all instances | ||
| * in one leaf node, then it will be set to None. The impurity stats may still be | ||
| * useful, so they are returned. | ||
| */ | ||
| private[impl] def chooseUnorderedCategoricalSplit( | ||
| featureIndex: Int, | ||
| values: Seq[Double], | ||
| labels: Seq[Double], | ||
| metadata: AltDTMetadata, | ||
| featureArity: Int): (Option[Split], ImpurityStats) = ??? | ||
| featureArity: Int): (Option[Split], ImpurityStats) = { | ||
| val categories: List[Double] = values.toSet.toList | ||
|
Owner
Remove. This is only used in the data validity check, and that check should be done at the beginning of learning, if ever.
Author
OK
||
|
|
||
| // Label stats for each category | ||
| val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)( | ||
| _ => metadata.createImpurityAggregator()) | ||
| require(categories.length <= featureArity, "Got more categories than featureArity") | ||
|
Owner
I'd say don't bother with this for now. Later on, we can add a validity check at the beginning of learning.
Author
OK
||
| values.zip(labels).foreach { case (cat, label) => | ||
| // NOTE: we assume the values for categorical features are Ints in [0,featureArity) | ||
| aggStats(cat.toInt).update(label) | ||
| } | ||
|
|
||
| // Aggregated statistics for left and right parts of split. | ||
| val leftImpurityAgg = metadata.createImpurityAggregator() | ||
| val rightImpurityAgg = metadata.createImpurityAggregator() | ||
|
|
||
| require(featureArity > 0, "Feature arity must be positive") | ||
|
Owner
Likewise, this check should be done at the beginning of learning.
Author
OK
||
| if (featureArity == 1) { | ||
| val impurityStats = new ImpurityStats(0.0, rightImpurityAgg.getCalculator.calculate(), | ||
|
Owner
The stats for this node need to be valid, not empty. The stats for children can be arbitrary.
Author
Not sure I completely understand, but I think the latest version fixes this by using
Owner
Yep, that should fix it.
||
| rightImpurityAgg.getCalculator, leftImpurityAgg.getCalculator, | ||
| rightImpurityAgg.getCalculator) | ||
| (None, impurityStats) | ||
| } else { | ||
| // TODO: We currently add and remove the stats for all categories for each split. | ||
| // A better way to do it would be to consider splits in an order such that each iteration | ||
| // only requires addition/removal of a single category and a single add/subtract to | ||
| // leftCount and rightCount. | ||
|
Owner
Note "gray codes" here
Author
OK
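For reference, the Gray-code ordering mentioned above could look roughly like the sketch below. It is not code from this PR: `GrayCodeSplits`, `walk`, and `leftCounts` are hypothetical names, and plain per-label count arrays stand in for the impurity aggregators. The idea is that consecutive Gray codes differ in exactly one bit, so each step toggles a single category in or out of the left side.

```scala
object GrayCodeSplits {
  /** i-th reflected binary Gray code. */
  def gray(i: Int): Int = i ^ (i >>> 1)

  /**
   * Visits every non-empty subset of categories {0, ..., arity - 2} as a bitmask.
   * Returns (mask, category toggled to reach this mask from the previous one).
   * The toggled bit at step i is the index of the lowest set bit of i.
   */
  def walk(arity: Int): Seq[(Int, Int)] = {
    val numSplits = (1 << (arity - 1)) - 1
    (1 to numSplits).map { i =>
      (gray(i), Integer.numberOfTrailingZeros(i))
    }
  }

  /**
   * Left-side label counts for each mask, maintained with a single
   * add or subtract of one category's counts per step.
   * catCounts(c)(k) is the (assumed) count of label k within category c.
   */
  def leftCounts(arity: Int, catCounts: Array[Array[Double]]): Seq[(Int, Array[Double])] = {
    val left = Array.fill(catCounts.head.length)(0.0)
    walk(arity).map { case (mask, cat) =>
      // If the toggled bit is now set, the category joined the left side; otherwise it left.
      val sign = if ((mask & (1 << cat)) != 0) 1.0 else -1.0
      var k = 0
      while (k < left.length) {
        left(k) += sign * catCounts(cat)(k)
        k += 1
      }
      (mask, left.clone())
    }
  }
}
```

With `arity = 4` the masks come out in the order 1, 3, 2, 6, 7, 5, 4, which covers the same 7 candidate left sides as the plain counting order, just rearranged so that neighbours differ by one category.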
||
| val splits: Array[CategoricalSplit] = findSplits(featureIndex, featureArity, metadata) | ||
| var bestSplit: Option[CategoricalSplit] = None | ||
| val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() | ||
| var bestGain: Double = -1.0 | ||
| aggStats.foreach(rightImpurityAgg.add) | ||
| val fullImpurity = rightImpurityAgg.getCalculator.calculate() | ||
| aggStats.foreach(rightImpurityAgg.subtract) | ||
|
Owner
Instead, how about renaming rightImpurityAgg to fullImpurityAgg, and just have it store all stats. Then for each split, you can calculate leftImpurityAgg as you're doing below, and then compute rightImpurityAgg by subtracting the 2 aggregators. Similar amount of code, and a little faster.
Author
OK
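A minimal sketch of the suggestion above, assuming per-label count arrays in place of the aggregators (the aggregator API is not reproduced here, and `SubtractAggregates` is a hypothetical name). The full counts are stored once, only the left side is accumulated per split, and the right side is obtained by subtraction. Entropy is computed in bits as an assumption for illustration.

```scala
object SubtractAggregates {
  /** Entropy in bits of a vector of label counts; 0 for an empty vector. */
  def entropy(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0
    else counts.filter(_ > 0).map { c =>
      val p = c / total
      -p * math.log(p) / math.log(2)
    }.sum
  }

  /** Information gain of a split, given the full counts and the left-side counts only. */
  def gain(full: Array[Double], left: Array[Double]): Double = {
    // Right side is derived, not accumulated: right = full - left.
    val right = full.zip(left).map { case (f, l) => f - l }
    val n = full.sum
    entropy(full) - (left.sum / n) * entropy(left) - (right.sum / n) * entropy(right)
  }
}
```

For example, with `full = Array(2.0, 3.0)` and `left = Array(0.0, 3.0)`, both children are pure, so the gain equals the entropy of the full node.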
||
| val fullCount: Double = values.size | ||
| for (split <- splits) { | ||
| // Update left, right impurity stats | ||
| split.leftCategories.foreach(c => leftImpurityAgg.add(aggStats(c.toInt))) | ||
| split.rightCategories.foreach(c => rightImpurityAgg.add(aggStats(c.toInt))) | ||
| val leftCount = leftImpurityAgg.getCount | ||
| val rightCount = rightImpurityAgg.getCount | ||
| // Compute impurity | ||
| val leftWeight = leftCount / fullCount | ||
| val rightWeight = rightCount / fullCount | ||
| val leftImpurity = leftImpurityAgg.getCalculator.calculate() | ||
| val rightImpurity = rightImpurityAgg.getCalculator.calculate() | ||
| val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity | ||
| if (gain > bestGain && gain > metadata.minInfoGain) { | ||
| bestSplit = Some(split) | ||
| leftImpurityAgg.stats.copyToArray(bestLeftImpurityAgg.stats) | ||
| bestGain = gain | ||
| } | ||
| // Reset left, right impurity stats | ||
| split.leftCategories.foreach(c => leftImpurityAgg.subtract(aggStats(c.toInt))) | ||
|
Owner
Just zero things out; no need for subtraction.
Author
OK
||
| split.rightCategories.foreach(c => rightImpurityAgg.subtract(aggStats(c.toInt))) | ||
| } | ||
|
|
||
| val bestFeatureSplit = bestSplit match { | ||
| case Some(split) => | ||
| new CategoricalSplit(featureIndex, split.leftCategories, featureArity) | ||
| case None => | ||
| throw new AssertionError("Unknown error in AltDT unordered categorical split selection") | ||
|
Owner
This may not be an unknown error. It will happen if no splits exceed minInfoGain. It should return None for the split.
Author
OK
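One way to express "return None when nothing exceeds minInfoGain" without throwing is sketched below. `BestSplitOrNone.choose` is a hypothetical helper, not part of this PR; it takes already-computed (split, gain) pairs so that the selection rule can be shown in isolation.

```scala
object BestSplitOrNone {
  /**
   * Picks the candidate with the largest gain, or None when no candidate's
   * gain strictly exceeds minInfoGain. The caller can still return the
   * node's impurity stats alongside the None.
   */
  def choose[S](candidates: Seq[(S, Double)], minInfoGain: Double): Option[(S, Double)] = {
    val eligible = candidates.filter { case (_, gain) => gain > minInfoGain }
    if (eligible.isEmpty) None else Some(eligible.maxBy(_._2))
  }
}
```

The strict comparison mirrors the `gain > metadata.minInfoGain` condition in the loop above, so a candidate whose gain is exactly `minInfoGain` is not selected.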
||
|
|
||
|
Owner
remove newline
Author
OK
||
| } | ||
| val fullImpurityAgg = leftImpurityAgg.deepCopy().add(rightImpurityAgg) | ||
| val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg) | ||
| val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, | ||
| fullImpurityAgg.getCalculator, bestLeftImpurityAgg.getCalculator, | ||
| bestRightImpurityAgg.getCalculator) | ||
| (Some(bestFeatureSplit), bestImpurityStats) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Returns all possible subsets of categories for categorical splits. | ||
| */ | ||
| private def findSplits( | ||
| featureIndex: Int, | ||
| featureArity: Int, | ||
| metadata: AltDTMetadata): Array[CategoricalSplit] = { | ||
| // Unordered features | ||
| // 2^(featureArity - 1) - 1 combinations | ||
| val numSplits = (1 << (featureArity - 1)) - 1 | ||
| val splits = new Array[CategoricalSplit](numSplits) | ||
|
|
||
| var splitIndex = 0 | ||
| while (splitIndex < numSplits) { | ||
| val categories: List[Double] = | ||
| extractMultiClassCategories(splitIndex + 1, featureArity) | ||
| splits(splitIndex) = | ||
| new CategoricalSplit(featureIndex, categories.toArray, featureArity) | ||
| splitIndex += 1 | ||
| } | ||
| splits | ||
| } | ||
|
|
||
| /** | ||
| * Nested method to extract list of eligible categories given an index. It extracts the | ||
| * positions of ones in the binary representation of the input. If the binary | ||
| * representation of a number is 01101 (13), the output list should be (3.0, 2.0, | ||
| * 0.0). maxFeatureValue is the number of rightmost digits that will be tested for ones. | ||
| */ | ||
| private def extractMultiClassCategories( | ||
|
||
| input: Int, | ||
| maxFeatureValue: Int): List[Double] = { | ||
| var categories = List[Double]() | ||
| var j = 0 | ||
| var bitShiftedInput = input | ||
| while (j < maxFeatureValue) { | ||
| if (bitShiftedInput % 2 != 0) { | ||
| // updating the list of categories. | ||
| categories = j.toDouble :: categories | ||
| } | ||
| // Right shift by one | ||
| bitShiftedInput = bitShiftedInput >> 1 | ||
| j += 1 | ||
| } | ||
| categories | ||
| } | ||
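To make the bitmask convention concrete, here is a standalone sketch that mirrors the extraction described in the doc comment. `MaskToCategories`, `categoriesOf`, and `allLeftSides` are hypothetical names used only for illustration; the behavior is intended to match `extractMultiClassCategories` and the `splitIndex + 1` enumeration in `findSplits`, but it has not been run against the PR code.

```scala
object MaskToCategories {
  /** Positions of set bits among the lowest numBits bits, highest position first. */
  def categoriesOf(mask: Int, numBits: Int): List[Double] =
    (0 until numBits)
      .filter(j => ((mask >> j) & 1) == 1)
      .map(_.toDouble)
      .reverse
      .toList

  /** Left-side category lists for masks 1 .. 2^(arity - 1) - 1. */
  def allLeftSides(arity: Int): List[List[Double]] = {
    val numSplits = (1 << (arity - 1)) - 1
    (1 to numSplits).map(mask => categoriesOf(mask, arity)).toList
  }
}
```

For `arity = 4` this yields 7 left sides, all drawn from categories 0 to 2; category 3 never appears on the left, which avoids counting each partition twice. Mask 5 (binary 101) gives `List(2.0, 0.0)`, the left side expected by the basic test case in `AltDTSuite`.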
|
|
||
| /** | ||
| * Choose splitting rule: feature value <= threshold | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -231,7 +231,26 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| assert(stats.valid) | ||
| } | ||
|
|
||
| // test("chooseUnorderedCategoricalSplit: basic case") { } | ||
| test("chooseUnorderedCategoricalSplit: basic case") { | ||
|
Owner
Could you please add 1 more test, analogous to the others with "return bad split if we should not split" in the test name? That will cover the case where we expect to get valid stats back, even when we do not find a split.
Author
OK
||
| val featureIndex = 0 | ||
| val featureArity = 4 | ||
| val values = Seq(3.0, 1.0, 0.0, 2.0, 2.0) | ||
| val labels = Seq(0.0, 0.0, 1.0, 1.0, 1.0) | ||
| val impurity = Entropy | ||
| val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity) | ||
| val (split, stats) = AltDT.chooseUnorderedCategoricalSplit( | ||
| featureIndex, values, labels, metadata, featureArity) | ||
| split match { | ||
| case Some(s: CategoricalSplit) => | ||
| assert(s.featureIndex === featureIndex) | ||
| assert(s.leftCategories.toSet === Set(0.0, 2.0)) | ||
| assert(s.rightCategories.toSet === Set(1.0, 3.0)) | ||
| // TODO: test correctness of stats | ||
| case _ => | ||
| throw new AssertionError( | ||
| s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") | ||
| } | ||
| } | ||
|
|
||
| // test("chooseUnorderedCategoricalSplit: return bad split if we should not split") { } | ||
|
|
||
|
|
||
Remove "possibly"
OK