Skip to content
107 changes: 105 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private[ml] object AltDT extends Logging {
parentUID: Option[String] = None): DecisionTreeModel = {
// TODO: Check validity of params
val rootNode = trainImpl(input, strategy)
RandomForest.finalizeTree(rootNode, strategy.algo, strategy.numClasses, parentUID)
impl.RandomForest.finalizeTree(rootNode, strategy.algo, strategy.numClasses, parentUID)
}

private[impl] def trainImpl(input: RDD[LabeledPoint], strategy: Strategy): Node = {
Expand Down Expand Up @@ -529,12 +529,115 @@ private[ml] object AltDT extends Logging {
}
}

/**
 * Find the best split for an unordered categorical feature at a single node.
 *
 * Algorithm:
 *  - Aggregates label statistics per category.
 *  - Considers all possible subsets of categories (exponentially many) as the left side
 *    of the split and keeps the one with the largest information gain.
 *
 * @param featureIndex  Index of feature being split.
 * @param values  Feature values at this node. Each value is assumed to be an integer
 *                category index in [0, featureArity). This method only aggregates per
 *                category, so it does not rely on the values being sorted.
 * @param labels  Labels corresponding to values, in the same order.
 * @param metadata  Learning metadata; supplies the impurity aggregators and `minInfoGain`.
 * @param featureArity  Number of categories of this feature.
 * @return  (best split, statistics for split).  The split is None if `featureArity` is 1 or
 *          if no candidate has a gain greater than `metadata.minInfoGain`, i.e. all instances
 *          stay in one leaf node.  The impurity stats may still be useful, so they are
 *          returned in every case.
 */
private[impl] def chooseUnorderedCategoricalSplit(
    featureIndex: Int,
    values: Seq[Double],
    labels: Seq[Double],
    metadata: AltDTMetadata,
    featureArity: Int): (Option[Split], ImpurityStats) = {

  // Label stats for each category
  val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)(
    _ => metadata.createImpurityAggregator())
  values.zip(labels).foreach { case (cat, label) =>
    // NOTE: we assume the values for categorical features are Ints in [0,featureArity)
    aggStats(cat.toInt).update(label)
  }

  // Aggregated statistics for left part of split and entire split.
  // fullImpurityAgg is treated as read-only after this point.
  val leftImpurityAgg = metadata.createImpurityAggregator()
  val fullImpurityAgg = metadata.createImpurityAggregator()
  aggStats.foreach(fullImpurityAgg.add)
  val fullImpurity = fullImpurityAgg.getCalculator.calculate()

  if (featureArity == 1) {
    // All instances go right; the left side is the (still empty) left aggregator.
    val impurityStats = new ImpurityStats(0.0, fullImpurity,
      fullImpurityAgg.getCalculator, leftImpurityAgg.getCalculator,
      fullImpurityAgg.getCalculator)
    (None, impurityStats)
  } else {
    // TODO: We currently add the stats for all left categories from scratch for each split.
    //       A better way to do it would be to consider splits in an order such that each
    //       iteration only requires addition/removal of a single category (e.g. by
    //       enumerating the subsets in gray-code order).
    val splits: Array[CategoricalSplit] = findSplits(featureIndex, featureArity, metadata)
    var bestSplit: Option[CategoricalSplit] = None
    // Starts out as a copy of the empty left aggregator; overwritten whenever a better
    // split is found.
    val bestLeftImpurityAgg = leftImpurityAgg.deepCopy()
    var bestGain: Double = -1.0
    val fullCount: Double = values.size
    for (split <- splits) {
      // Rebuild the left-side stats for this candidate subset.
      leftImpurityAgg.clear()
      split.leftCategories.foreach(c => leftImpurityAgg.add(aggStats(c.toInt)))
      // Derive the right-side stats from a copy, so that the node-wide aggregate is never
      // modified and does not have to be restored afterwards.
      val rightImpurityAgg = fullImpurityAgg.deepCopy().subtract(leftImpurityAgg)
      val leftCount = leftImpurityAgg.getCount
      val rightCount = rightImpurityAgg.getCount
      // Compute impurity
      val leftWeight = leftCount / fullCount
      val rightWeight = rightCount / fullCount
      val leftImpurity = leftImpurityAgg.getCalculator.calculate()
      val rightImpurity = rightImpurityAgg.getCalculator.calculate()
      val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
      if (gain > bestGain && gain > metadata.minInfoGain) {
        bestSplit = Some(split)
        leftImpurityAgg.stats.copyToArray(bestLeftImpurityAgg.stats)
        bestGain = gain
      }
    }

    val bestFeatureSplit = bestSplit.map { split =>
      new CategoricalSplit(featureIndex, split.leftCategories, featureArity)
    }
    val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
    val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity,
      fullImpurityAgg.getCalculator, bestLeftImpurityAgg.getCalculator,
      bestRightImpurityAgg.getCalculator)
    (bestFeatureSplit, bestImpurityStats)
  }
}

/**
 * Returns all candidate splits for an unordered categorical feature.
 *
 * Produces 2^(featureArity - 1) - 1 splits.  The split at array position i takes its left
 * categories from `RandomForest.extractMultiClassCategories(i + 1, featureArity)`.
 *
 * @param featureIndex  Index of the feature the splits are built for.
 * @param featureArity  Number of categories of the feature.  Must be in [1, 31] so that the
 *                      number of splits can be computed without Int overflow.
 * @param metadata  Learning metadata.  Not used by this method at present.
 * @return  Array of candidate splits; empty when featureArity is 1.
 * @throws IllegalArgumentException if featureArity is outside [1, 31].
 */
private def findSplits(
    featureIndex: Int,
    featureArity: Int,
    metadata: AltDTMetadata): Array[CategoricalSplit] = {
  // Shift distances on Int are taken modulo 32, so an arity of 0 or >= 32 would silently
  // produce a nonsensical split count (huge, or zero) instead of failing.
  require(featureArity >= 1 && featureArity <= 31,
    s"findSplits requires featureArity in [1, 31] but got $featureArity")

  // Unordered features
  // 2^(featureArity - 1) - 1 combinations
  val numSplits = (1 << (featureArity - 1)) - 1
  Array.tabulate[CategoricalSplit](numSplits) { splitIndex =>
    val categories: List[Double] =
      RandomForest.extractMultiClassCategories(splitIndex + 1, featureArity)
    new CategoricalSplit(featureIndex, categories.toArray, featureArity)
  }
}

/**
* Choose splitting rule: feature value <= threshold
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,26 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(stats.valid)
}

// test("chooseUnorderedCategoricalSplit: basic case") { }
test("chooseUnorderedCategoricalSplit: basic case") {
  // Categories 0 and 2 only carry label 1.0, categories 1 and 3 only carry label 0.0,
  // so the expected split separates {0, 2} from {1, 3}.
  val featureIndex = 0
  val featureArity = 4
  val values = Seq(3.0, 1.0, 0.0, 2.0, 2.0)
  val labels = Seq(0.0, 0.0, 1.0, 1.0, 1.0)
  val impurity = Entropy
  val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
  val (split, stats) = AltDT.chooseUnorderedCategoricalSplit(
    featureIndex, values, labels, metadata, featureArity)
  split match {
    case Some(s: CategoricalSplit) =>
      assert(s.featureIndex === featureIndex)
      assert(s.leftCategories.toSet === Set(0.0, 2.0))
      assert(s.rightCategories.toSet === Set(1.0, 3.0))
      // TODO: test correctness of stats
    case _ =>
      // Report the value itself: the class of the Option wrapper would only ever
      // print "Some" or "None$".
      throw new AssertionError(
        s"Expected Some(CategoricalSplit) but got $split")
  }
}

// test("chooseUnorderedCategoricalSplit: return bad split if we should not split") { }

Expand Down