Skip to content

Commit 1ed14f5

Browse files
committed
Merge pull request #12 from feynmanliang/dt-unordered-categorical
Implements chooseUnorderedCategoricalSplit
2 parents bec5565 + ee8037a commit 1ed14f5

File tree

2 files changed

+142
-4
lines changed

2 files changed

+142
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ private[ml] object AltDT extends Logging {
102102
parentUID: Option[String] = None): DecisionTreeModel = {
103103
// TODO: Check validity of params
104104
val rootNode = trainImpl(input, strategy)
105-
RandomForest.finalizeTree(rootNode, strategy.algo, strategy.numClasses, parentUID)
105+
impl.RandomForest.finalizeTree(rootNode, strategy.algo, strategy.numClasses, parentUID)
106106
}
107107

108108
private[impl] def trainImpl(input: RDD[LabeledPoint], strategy: Strategy): Node = {
@@ -529,12 +529,114 @@ private[ml] object AltDT extends Logging {
529529
}
530530
}
531531

532+
/**
 * Find the best split for an unordered categorical feature at a single node.
 *
 * Algorithm:
 *  - Considers all possible subsets (exponentially many)
 *
 * @param featureIndex Index of feature being split.
 * @param values Feature values at this node. Sorted in increasing order.
 * @param labels Labels corresponding to values, in the same order.
 * @param metadata Provides the impurity-aggregator factory and `minInfoGain` threshold.
 * @param featureArity Number of categories; values are assumed to be Ints in [0, featureArity).
 * @return (best split, statistics for split)  If the best split actually puts all instances
 *         in one leaf node, then it will be set to None.  The impurity stats may still be
 *         useful, so they are returned.
 */
private[impl] def chooseUnorderedCategoricalSplit(
    featureIndex: Int,
    values: Seq[Double],
    labels: Seq[Double],
    metadata: AltDTMetadata,
    featureArity: Int): (Option[Split], ImpurityStats) = {

  // Label stats for each category
  val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)(
    _ => metadata.createImpurityAggregator())
  values.zip(labels).foreach { case (cat, label) =>
    // NOTE: we assume the values for categorical features are Ints in [0,featureArity)
    aggStats(cat.toInt).update(label)
  }

  // Aggregated statistics for left part of split and entire split.
  val leftImpurityAgg = metadata.createImpurityAggregator()
  val fullImpurityAgg = metadata.createImpurityAggregator()
  aggStats.foreach(fullImpurityAgg.add)
  val fullImpurity = fullImpurityAgg.getCalculator.calculate()

  if (featureArity == 1) {
    // All instances go right: a single category admits no binary partition.
    val impurityStats = new ImpurityStats(0.0, fullImpurityAgg.getCalculator.calculate(),
      fullImpurityAgg.getCalculator, leftImpurityAgg.getCalculator,
      fullImpurityAgg.getCalculator)
    (None, impurityStats)
  } else {
    // TODO: We currently add and remove the stats for all categories for each split.
    // A better way to do it would be to consider splits in an order such that each iteration
    // only requires addition/removal of a single category and a single add/subtract to
    // leftCount and rightCount.
    // TODO: Use more efficient encoding such as gray codes
    val splits: Array[CategoricalSplit] = findSplits(featureIndex, featureArity, metadata)
    var bestSplit: Option[CategoricalSplit] = None
    val bestLeftImpurityAgg = leftImpurityAgg.deepCopy()
    var bestGain: Double = -1.0  // search sentinel; never reported to callers (see below)
    val fullCount: Double = values.size
    for (split <- splits) {
      // Update left, right impurity stats
      split.leftCategories.foreach(c => leftImpurityAgg.add(aggStats(c.toInt)))
      val rightImpurityAgg = fullImpurityAgg.deepCopy().subtract(leftImpurityAgg)
      val leftCount = leftImpurityAgg.getCount
      val rightCount = rightImpurityAgg.getCount
      // Compute impurity
      val leftWeight = leftCount / fullCount
      val rightWeight = rightCount / fullCount
      val leftImpurity = leftImpurityAgg.getCalculator.calculate()
      val rightImpurity = rightImpurityAgg.getCalculator.calculate()
      val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
      if (gain > bestGain && gain > metadata.minInfoGain) {
        bestSplit = Some(split)
        leftImpurityAgg.stats.copyToArray(bestLeftImpurityAgg.stats)
        bestGain = gain
      }
      // Reset left impurity stats
      leftImpurityAgg.clear()
    }

    val bestFeatureSplit = bestSplit.map { split =>
      new CategoricalSplit(featureIndex, split.leftCategories, featureArity)
    }
    // BUG FIX: if no candidate cleared minInfoGain, bestGain is still the -1.0 sentinel.
    // Report 0.0 instead, consistent with the featureArity == 1 branch above (and with the
    // no-split test, which asserts stats.gain === 0.0).
    val reportedGain = if (bestSplit.isDefined) bestGain else 0.0
    val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
    val bestImpurityStats = new ImpurityStats(reportedGain, fullImpurity,
      fullImpurityAgg.getCalculator, bestLeftImpurityAgg.getCalculator,
      bestRightImpurityAgg.getCalculator)
    (bestFeatureSplit, bestImpurityStats)
  }
}
617+
618+
/**
 * Returns all possible subsets of features for categorical splits.
 */
private def findSplits(
    featureIndex: Int,
    featureArity: Int,
    metadata: AltDTMetadata): Array[CategoricalSplit] = {
  // Unordered features: 2^(featureArity - 1) - 1 distinct binary partitions.
  val numSplits = (1 << (featureArity - 1)) - 1
  // Bit pattern i + 1 (in [1, numSplits]) selects the categories sent to the left child.
  Array.tabulate(numSplits) { splitIndex =>
    val categories: List[Double] =
      RandomForest.extractMultiClassCategories(splitIndex + 1, featureArity)
    new CategoricalSplit(featureIndex, categories.toArray, featureArity)
  }
}
538640

539641
/**
540642
* Choose splitting rule: feature value <= threshold

mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,45 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
231231
assert(stats.valid)
232232
}
233233

234-
// test("chooseUnorderedCategoricalSplit: basic case") { }
234+
test("chooseUnorderedCategoricalSplit: basic case") {
  val featureIndex = 0
  val featureArity = 4
  val values = Seq(3.0, 1.0, 0.0, 2.0, 2.0)
  val labels = Seq(0.0, 0.0, 1.0, 1.0, 1.0)
  val impurity = Entropy
  val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
  val (split, stats) = AltDT.chooseUnorderedCategoricalSplit(
    featureIndex, values, labels, metadata, featureArity)
  split match {
    case Some(s: CategoricalSplit) =>
      assert(s.featureIndex === featureIndex)
      // Categories {0, 2} carry label 1.0 and {1, 3} carry label 0.0, so the perfect
      // partition separates them.
      assert(s.leftCategories.toSet === Set(0.0, 2.0))
      assert(s.rightCategories.toSet === Set(1.0, 3.0))
      // TODO: test correctness of stats
    case _ =>
      // BUG FIX: the old message used split.getClass.getSimpleName, which describes the
      // Option wrapper (Some$/None$), not the split. Show the value itself instead.
      throw new AssertionError(s"Expected Some(CategoricalSplit) but got $split")
  }
}
235254

236-
// test("chooseUnorderedCategoricalSplit: return bad split if we should not split") { }
255+
test("chooseUnorderedCategoricalSplit: return bad split if we should not split") {
  val featureIndex = 0
  val featureArity = 4
  val values = Seq(3.0, 1.0, 0.0, 2.0, 2.0)
  // All labels identical, so no split can reduce impurity.
  val labels = Seq(1.0, 1.0, 1.0, 1.0, 1.0)
  val impurity = Entropy
  val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
  // BUG FIX: this test previously called chooseOrderedCategoricalSplit, so the unordered
  // implementation's no-split path (the method this test is named for) was never exercised.
  val (split, stats) =
    AltDT.chooseUnorderedCategoricalSplit(featureIndex, values, labels, metadata, featureArity)
  assert(split.isEmpty)
  val fullImpurityStatsArray =
    Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
  val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
  // A rejected split must report zero gain, not an internal sentinel value.
  assert(stats.gain === 0.0)
  assert(stats.impurity === fullImpurity)
  assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
  assert(stats.valid)
}
237273

238274
test("chooseContinuousSplit: basic case") {
239275
val featureIndex = 0

0 commit comments

Comments
 (0)