@@ -102,7 +102,7 @@ private[ml] object AltDT extends Logging {
102102 parentUID : Option [String ] = None ): DecisionTreeModel = {
103103 // TODO: Check validity of params
104104 val rootNode = trainImpl(input, strategy)
105- RandomForest .finalizeTree(rootNode, strategy.algo, strategy.numClasses, parentUID)
105+ impl. RandomForest .finalizeTree(rootNode, strategy.algo, strategy.numClasses, parentUID)
106106 }
107107
108108 private [impl] def trainImpl (input : RDD [LabeledPoint ], strategy : Strategy ): Node = {
@@ -529,12 +529,114 @@ private[ml] object AltDT extends Logging {
529529 }
530530 }
531531
/**
 * Select the best binary split on an unordered categorical feature for a single node.
 *
 * Every candidate produced by `findSplits` is scored (the number of candidates grows
 * exponentially with the feature arity). A candidate's gain is the impurity of the whole node
 * minus the count-weighted impurities of its left and right sides; the candidate with the
 * largest gain that also exceeds `metadata.minInfoGain` is kept.
 *
 * @param featureIndex  Index of feature being split.
 * @param values  Feature values at this node. Sorted in increasing order. Each value is assumed
 *                to be an integer-valued category in [0, featureArity).
 * @param labels  Labels corresponding to values, in the same order.
 * @param metadata  Source of the impurity aggregators and of the minimum info gain threshold.
 * @param featureArity  Number of categories of the feature.
 * @return  (best split, statistics for split). The split is None when the feature has a single
 *          category (reported gain 0.0) or when no candidate exceeds `metadata.minInfoGain`
 *          (reported gain -1.0). The impurity stats may still be useful in those cases, so
 *          they are always returned.
 */
private[impl] def chooseUnorderedCategoricalSplit(
    featureIndex: Int,
    values: Seq[Double],
    labels: Seq[Double],
    metadata: AltDTMetadata,
    featureArity: Int): (Option[Split], ImpurityStats) = {

  // One label aggregator per category, indexed by category id.
  val perCategoryAgg = Array.fill[ImpurityAggregatorSingle](featureArity)(
    metadata.createImpurityAggregator())
  values.zip(labels).foreach { pair =>
    // NOTE: categorical values are assumed to be Ints in [0, featureArity).
    perCategoryAgg(pair._1.toInt).update(pair._2)
  }

  // Scratch aggregator for the left side of a candidate, and the aggregate over the whole node.
  val leftAgg = metadata.createImpurityAggregator()
  val totalAgg = metadata.createImpurityAggregator()
  perCategoryAgg.foreach(agg => totalAgg.add(agg))
  val totalImpurity = totalAgg.getCalculator.calculate()

  if (featureArity != 1) {
    // TODO: The stats of every left category are added and then cleared for each candidate.
    //       Visiting candidates in an order where consecutive ones differ by one category
    //       would need only a single add/remove per iteration.
    // TODO: Use a more efficient encoding such as gray codes.
    val candidates: Array[CategoricalSplit] = findSplits(featureIndex, featureArity, metadata)
    val numInstances: Double = values.size
    // Holds a snapshot of the left-side stats of the best candidate found so far; it starts
    // out as a copy of the still-untouched left aggregator.
    val winningLeftAgg = leftAgg.deepCopy()
    var winner: Option[CategoricalSplit] = None
    var winningGain: Double = -1.0

    candidates.foreach { candidate =>
      // Build the left side from its categories; the right side is the remainder of the node.
      candidate.leftCategories.foreach { category =>
        leftAgg.add(perCategoryAgg(category.toInt))
      }
      val rightAgg = totalAgg.deepCopy().subtract(leftAgg)

      val leftFraction = leftAgg.getCount / numInstances
      val rightFraction = rightAgg.getCount / numInstances
      val leftImpurity = leftAgg.getCalculator.calculate()
      val rightImpurity = rightAgg.getCalculator.calculate()
      val gain = totalImpurity - leftFraction * leftImpurity - rightFraction * rightImpurity

      if (gain > winningGain && gain > metadata.minInfoGain) {
        winner = Some(candidate)
        leftAgg.stats.copyToArray(winningLeftAgg.stats)
        winningGain = gain
      }

      // The scratch aggregator must be empty again before the next candidate.
      leftAgg.clear()
    }

    val chosenSplit = winner.map { best =>
      new CategoricalSplit(featureIndex, best.leftCategories, featureArity)
    }
    val winningRightAgg = totalAgg.deepCopy().subtract(winningLeftAgg)
    val winningStats = new ImpurityStats(
      winningGain,
      totalImpurity,
      totalAgg.getCalculator,
      winningLeftAgg.getCalculator,
      winningRightAgg.getCalculator)
    (chosenSplit, winningStats)
  } else {
    // Only one category exists, so nothing can be separated: all instances go right.
    val singleCategoryStats = new ImpurityStats(
      0.0,
      totalAgg.getCalculator.calculate(),
      totalAgg.getCalculator,
      leftAgg.getCalculator,
      totalAgg.getCalculator)
    (None, singleCategoryStats)
  }
}
617+
/**
 * Enumerates the candidate splits for an unordered categorical feature.
 *
 * Produces 2^(featureArity - 1) - 1 splits. The split at position i uses the left categories
 * returned by `RandomForest.extractMultiClassCategories(i + 1, featureArity)`.
 *
 * @param featureIndex  Index of the feature the splits are built for.
 * @param featureArity  Number of categories of the feature. Must lie in [1, 31] so that the
 *                      number of splits is representable as a non-negative Int.
 * @param metadata  Not used by this method at present.
 * @return  Array of candidate splits; empty when `featureArity` is 1.
 * @throws IllegalArgumentException if `featureArity` is outside [1, 31].
 */
private def findSplits(
    featureIndex: Int,
    featureArity: Int,
    metadata: AltDTMetadata): Array[CategoricalSplit] = {
  // Int shifts use only the low 5 bits of the shift distance, so an arity outside [1, 31]
  // would silently wrap: arity 0 or 32 yields Int.MaxValue splits, arity 33 yields none.
  // Fail loudly instead of allocating a bogus array or skipping every candidate.
  require(featureArity >= 1 && featureArity <= 31,
    s"Unordered categorical feature $featureIndex has arity $featureArity, but the arity" +
      " must be in [1, 31] to enumerate its category subsets.")

  // Unordered features: 2^(featureArity - 1) - 1 combinations.
  val numSplits = (1 << (featureArity - 1)) - 1
  Array.tabulate[CategoricalSplit](numSplits) { splitIndex =>
    val categories: List[Double] =
      RandomForest.extractMultiClassCategories(splitIndex + 1, featureArity)
    new CategoricalSplit(featureIndex, categories.toArray, featureArity)
  }
}
538640
539641 /**
540642 * Choose splitting rule: feature value <= threshold
0 commit comments