133 changes: 132 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala
@@ -529,12 +529,143 @@ private[ml] object AltDT extends Logging {
}
}

/**
* Find the best split for an unordered categorical feature at a single node.
*
* Algorithm:
* - Considers all possible subsets (possibly exponentially many)
Owner: Remove "possibly"

Author: OK

*
* @param featureIndex Index of feature being split.
* @param values Feature values at this node. Sorted in increasing order.
* @param labels Labels corresponding to values, in the same order.
* @return (best split, statistics for split) If the best split puts all instances
*         in one leaf node, then the split will be None. The impurity stats may still be
*         useful, so they are returned either way.
*/
private[impl] def chooseUnorderedCategoricalSplit(
featureIndex: Int,
values: Seq[Double],
labels: Seq[Double],
metadata: AltDTMetadata,
featureArity: Int): (Option[Split], ImpurityStats) = ???
featureArity: Int): (Option[Split], ImpurityStats) = {
val categories: List[Double] = values.toSet.toList
Owner: Remove. This is only used in the data validity check, and that check should be done at the beginning of learning, if ever.

Author: OK


// Label stats for each category
val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)(
_ => metadata.createImpurityAggregator())
require(categories.length <= featureArity, "Got more categories than featureArity")
Owner: I'd say don't bother with this for now. Later on, we can add a validity check at the beginning of learning.

Author: OK

values.zip(labels).foreach { case (cat, label) =>
// NOTE: we assume the values for categorical features are Ints in [0,featureArity)
aggStats(cat.toInt).update(label)
}

// Aggregated statistics for left and right parts of split.
val leftImpurityAgg = metadata.createImpurityAggregator()
val rightImpurityAgg = metadata.createImpurityAggregator()

require(featureArity > 0, "Feature arity must be positive")
Owner: Likewise, this check should be done at the beginning of learning.

Author: OK

if (featureArity == 1) {
val impurityStats = new ImpurityStats(0.0, rightImpurityAgg.getCalculator.calculate(),
Owner: The stats for this node need to be valid, not empty. The stats for children can be arbitrary.

Author: Not sure I completely understand, but I think the latest version fixes this by using fullImpurityAgg here.

Owner: Yep, that should fix it.

rightImpurityAgg.getCalculator, leftImpurityAgg.getCalculator,
rightImpurityAgg.getCalculator)
(None, impurityStats)
} else {
// TODO: We currently add and remove the stats for all categories for each split.
// A better way to do it would be to consider splits in an order such that each iteration
// only requires addition/removal of a single category and a single add/subtract to
// leftCount and rightCount.
Owner: Note "gray codes" here

Author: OK

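The TODO above, and the reviewer's "gray codes" note, refer to enumerating the category subsets so that consecutive splits differ by exactly one category. A minimal sketch of the idea (a hypothetical helper, not part of this patch):

```scala
// Enumerate the 2^k subset masks of k categories in Gray-code order:
// consecutive masks differ in exactly one bit, so each step would move
// a single category between the left and right side of the split.
def grayCodeMasks(k: Int): Seq[Int] =
  (0 until (1 << k)).map(i => i ^ (i >> 1))
```

For example, `grayCodeMasks(3)` yields 0, 1, 3, 2, 6, 7, 5, 4; each neighboring pair of masks differs in a single bit, which is what makes the single-category add/remove update possible.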
val splits: Array[CategoricalSplit] = findSplits(featureIndex, featureArity, metadata)
var bestSplit: Option[CategoricalSplit] = None
val bestLeftImpurityAgg = leftImpurityAgg.deepCopy()
var bestGain: Double = -1.0
aggStats.foreach(rightImpurityAgg.add)
val fullImpurity = rightImpurityAgg.getCalculator.calculate()
aggStats.foreach(rightImpurityAgg.subtract)
Owner: Instead, how about renaming rightImpurityAgg to fullImpurityAgg, and just have it store all stats. Then for each split, you can calculate leftImpurityAgg as you're doing below, and then compute rightImpurityAgg by subtracting the 2 aggregators. Similar amount of code, and a little faster.

Author: OK

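The reviewer's suggestion, keeping one aggregator with the full stats and deriving the right side by subtraction, can be shown in miniature with plain count arrays (hypothetical names, not the PR's ImpurityAggregator API):

```scala
// The full label counts stay fixed for the node; the right-side counts
// for any split are the full counts minus the left-side counts, so only
// the left side needs to be accumulated per candidate split.
def rightFromFull(full: Array[Double], left: Array[Double]): Array[Double] =
  full.zip(left).map { case (f, l) => f - l }
```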
val fullCount: Double = values.size
for (split <- splits) {
// Update left, right impurity stats
split.leftCategories.foreach(c => leftImpurityAgg.add(aggStats(c.toInt)))
split.rightCategories.foreach(c => rightImpurityAgg.add(aggStats(c.toInt)))
val leftCount = leftImpurityAgg.getCount
val rightCount = rightImpurityAgg.getCount
// Compute impurity
val leftWeight = leftCount / fullCount
val rightWeight = rightCount / fullCount
val leftImpurity = leftImpurityAgg.getCalculator.calculate()
val rightImpurity = rightImpurityAgg.getCalculator.calculate()
val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
if (gain > bestGain && gain > metadata.minInfoGain) {
bestSplit = Some(split)
leftImpurityAgg.stats.copyToArray(bestLeftImpurityAgg.stats)
bestGain = gain
}
// Reset left, right impurity stats
split.leftCategories.foreach(c => leftImpurityAgg.subtract(aggStats(c.toInt)))
Owner: Just zero things out; no need for subtraction.

Author: OK

split.rightCategories.foreach(c => rightImpurityAgg.subtract(aggStats(c.toInt)))
}

val bestFeatureSplit = bestSplit match {
case Some(split) =>
new CategoricalSplit(featureIndex, split.leftCategories, featureArity)
case None =>
throw new AssertionError("Unknown error in AltDT unordered categorical split selection")
Owner: This may not be an unknown error. It will happen if no splits exceed minInfoGain. It should return None for the split.

Author: OK


Owner: remove newline

Author: OK

}
val fullImpurityAgg = leftImpurityAgg.deepCopy().add(rightImpurityAgg)
val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg)
val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity,
fullImpurityAgg.getCalculator, bestLeftImpurityAgg.getCalculator,
bestRightImpurityAgg.getCalculator)
(Some(bestFeatureSplit), bestImpurityStats)
}
}
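The gain computed in the loop above is the node's impurity minus the weighted impurities of the two children. A self-contained sketch with entropy as the impurity measure (standalone helpers over per-class count vectors, not the PR's ImpurityAggregator API):

```scala
// Entropy (base 2) of a class-count vector.
def entropy(counts: Seq[Double]): Double = {
  val total = counts.sum
  counts.filter(_ > 0).map { c =>
    val p = c / total
    -p * math.log(p) / math.log(2)
  }.sum
}

// Information gain of a binary split, given per-class counts on each side:
// gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
def infoGain(left: Seq[Double], right: Seq[Double]): Double = {
  val l = left.sum
  val r = right.sum
  val full = left.zip(right).map { case (a, b) => a + b }
  entropy(full) - (l / (l + r)) * entropy(left) - (r / (l + r)) * entropy(right)
}
```

For a perfect split with left class counts (0, 3) and right class counts (2, 0), the children are pure, so the gain equals the entropy of the full counts (2, 3): about 0.971 bits.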

/**
* Returns all possible subsets of features for categorical splits.
*/
private def findSplits(
featureIndex: Int,
featureArity: Int,
metadata: AltDTMetadata): Array[CategoricalSplit] = {
// Unordered features
// 2^(featureArity - 1) - 1 combinations
val numSplits = (1 << (featureArity - 1)) - 1
val splits = new Array[CategoricalSplit](numSplits)

var splitIndex = 0
while (splitIndex < numSplits) {
val categories: List[Double] =
extractMultiClassCategories(splitIndex + 1, featureArity)
splits(splitIndex) =
new CategoricalSplit(featureIndex, categories.toArray, featureArity)
splitIndex += 1
}
splits
}
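findSplits enumerates (1 << (featureArity - 1)) - 1 candidates: of the 2^k subsets of k categories, each split is counted once because a subset and its complement describe the same partition, and the empty-vs-all partition is excluded. As a sanity check, the count can be derived directly (a standalone sketch, not part of the patch):

```scala
// Count distinct binary partitions of k unordered categories directly:
// nonempty proper subset masks, deduplicated against their complements
// by keeping only the lexicographically smaller mask of each pair.
def countPartitions(k: Int): Int = {
  val all = (1 << k) - 1
  (1 until all).count(mask => mask < (all ^ mask))
}
```

This matches the closed form 2^(k - 1) - 1 used by findSplits, e.g. 7 candidate splits for a feature with arity 4.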

/**
* Helper method to extract the list of eligible categories given an index. It extracts the
* positions of ones in the binary representation of the input. If the binary
* representation of a number is 01101 (13), the output list should be (3.0, 2.0,
* 0.0). The maxFeatureValue parameter gives the number of rightmost bits that are tested for ones.
*/
private def extractMultiClassCategories(
Owner: Don't duplicate the method in RandomForest; just call it from here.

Author: OK

input: Int,
maxFeatureValue: Int): List[Double] = {
var categories = List[Double]()
var j = 0
var bitShiftedInput = input
while (j < maxFeatureValue) {
if (bitShiftedInput % 2 != 0) {
// updating the list of categories.
categories = j.toDouble :: categories
}
// Right shift by one
bitShiftedInput = bitShiftedInput >> 1
j += 1
}
categories
}
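To check the bit extraction above, the same logic can be exercised as a standalone copy (local names, identical semantics):

```scala
// Positions of the set bits in `input`, highest position first, as Doubles.
def extractCategories(input: Int, maxFeatureValue: Int): List[Double] = {
  var categories = List[Double]()
  var bits = input
  var j = 0
  while (j < maxFeatureValue) {
    if (bits % 2 != 0) {
      // Bit j is set: category j goes on the left side of the split.
      categories = j.toDouble :: categories
    }
    bits = bits >> 1
    j += 1
  }
  categories
}
```

For input 13 (binary 01101), the set bits are at positions 0, 2, and 3, so the result is List(3.0, 2.0, 0.0), matching the Scaladoc example.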

/**
* Choose splitting rule: feature value <= threshold
@@ -231,7 +231,26 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(stats.valid)
}

// test("chooseUnorderedCategoricalSplit: basic case") { }
test("chooseUnorderedCategoricalSplit: basic case") {
Owner: Could you please add 1 more test, analogous to the others with "return bad split if we should not split" in the test name? That will cover the case where we expect to get valid stats back, even when we do not find a split.

Author: OK

val featureIndex = 0
val featureArity = 4
val values = Seq(3.0, 1.0, 0.0, 2.0, 2.0)
val labels = Seq(0.0, 0.0, 1.0, 1.0, 1.0)
val impurity = Entropy
val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
val (split, stats) = AltDT.chooseUnorderedCategoricalSplit(
featureIndex, values, labels, metadata, featureArity)
split match {
case Some(s: CategoricalSplit) =>
assert(s.featureIndex === featureIndex)
assert(s.leftCategories.toSet === Set(0.0, 2.0))
assert(s.rightCategories.toSet === Set(1.0, 3.0))
// TODO: test correctness of stats
case _ =>
throw new AssertionError(
s"Expected CategoricalSplit but got $split")
}
}

// test("chooseUnorderedCategoricalSplit: return bad split if we should not split") { }
