@@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
130130
131131 // Find best split for all nodes at a level.
132132 timer.start(" findBestSplits" )
133- val splitsStatsForLevel : Array [(Split , InformationGainStats )] =
133+ val splitsStatsForLevel : Array [(Split , InformationGainStats , Predict )] =
134134 DecisionTree .findBestSplits(treeInput, parentImpurities,
135135 metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
136136 timer.stop(" findBestSplits" )
@@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
143143 timer.start(" extractNodeInfo" )
144144 val split = nodeSplitStats._1
145145 val stats = nodeSplitStats._2
146+ val predict = nodeSplitStats._3
146147 val isLeaf = (stats.gain <= 0 ) || (level == strategy.maxDepth)
147- val node = new Node (nodeIndex, stats. predict, isLeaf, Some (split), None , None , Some (stats))
148+ val node = new Node (nodeIndex, predict, isLeaf, Some (split), None , None , Some (stats))
148149 logDebug(" Node = " + node)
149150 nodes(nodeIndex) = node
150151 timer.stop(" extractNodeInfo" )
@@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
425426 splits : Array [Array [Split ]],
426427 bins : Array [Array [Bin ]],
427428 maxLevelForSingleGroup : Int ,
428- timer : TimeTracker = new TimeTracker ): Array [(Split , InformationGainStats )] = {
429+ timer : TimeTracker = new TimeTracker ): Array [(Split , InformationGainStats , Predict )] = {
429430 // split into groups to avoid memory overflow during aggregation
430431 if (level > maxLevelForSingleGroup) {
431432 // When information for all nodes at a given level cannot be stored in memory,
@@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
434435 // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
435436 val numGroups = 1 << level - maxLevelForSingleGroup
436437 logDebug(" numGroups = " + numGroups)
437- var bestSplits = new Array [(Split , InformationGainStats )](0 )
438+ var bestSplits = new Array [(Split , InformationGainStats , Predict )](0 )
438439 // Iterate over each group of nodes at a level.
439440 var groupIndex = 0
440441 while (groupIndex < numGroups) {
@@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
605606 bins : Array [Array [Bin ]],
606607 timer : TimeTracker ,
607608 numGroups : Int = 1 ,
608- groupIndex : Int = 0 ): Array [(Split , InformationGainStats )] = {
609+ groupIndex : Int = 0 ): Array [(Split , InformationGainStats , Predict )] = {
609610
610611 /*
611612 * The high-level descriptions of the best split optimizations are noted here.
@@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {
705706
706707 // Calculate best splits for all nodes at a given level
707708 timer.start(" chooseSplits" )
708- val bestSplits = new Array [(Split , InformationGainStats )](numNodes)
709+ val bestSplits = new Array [(Split , InformationGainStats , Predict )](numNodes)
709710 // Iterating over all nodes at this level
710711 var nodeIndex = 0
711712 while (nodeIndex < numNodes) {
@@ -747,18 +748,16 @@ object DecisionTree extends Serializable with Logging {
747748
748749 val totalCount = leftCount + rightCount
749750
750- val parentNodeAgg = leftImpurityCalculator.copy
751- parentNodeAgg.add(rightImpurityCalculator)
751+
752752 // impurity of parent node
753753 val impurity = if (level > 0 ) {
754754 topImpurity
755755 } else {
756+ val parentNodeAgg = leftImpurityCalculator.copy
757+ parentNodeAgg.add(rightImpurityCalculator)
756758 parentNodeAgg.calculate()
757759 }
758760
759- val predict = parentNodeAgg.predict
760- val prob = parentNodeAgg.prob(predict)
761-
762761 val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
763762 val rightImpurity = rightImpurityCalculator.calculate()
764763
@@ -770,7 +769,18 @@ object DecisionTree extends Serializable with Logging {
770769 return InformationGainStats .invalidInformationGainStats
771770 }
772771
773- new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict, prob)
772+ new InformationGainStats (gain, impurity, leftImpurity, rightImpurity)
773+ }
774+
// Builds the Predict for a node from the statistics of its two children.
//
// leftImpurityCalculator:  statistics for the left side of a candidate split.
// rightImpurityCalculator: statistics for the right side of the same split.
// Returns a new Predict holding the value reported by the combined statistics
// and the probability the combined statistics report for that value.
//
// NOTE(review): the local name suggests left + right together describe the
// parent (pre-split) node, which would make the result independent of which
// split supplied the two calculators -- confirm against ImpurityCalculator.
775+ private def calculatePredict (
776+ leftImpurityCalculator : ImpurityCalculator ,
777+ rightImpurityCalculator : ImpurityCalculator ): Predict = {
// Work on a copy so that the add below is applied to the copy rather than to
// the left calculator itself; add's return value is not used.
778+ val parentNodeAgg = leftImpurityCalculator.copy
779+ parentNodeAgg.add(rightImpurityCalculator)
// Read the prediction first, then ask the same aggregate for its probability.
780+ val predict = parentNodeAgg.predict
781+ val prob = parentNodeAgg.prob(predict)
782+
783+ new Predict (predict, prob)
774784 }
775785
776786 /**
@@ -786,12 +796,14 @@ object DecisionTree extends Serializable with Logging {
786796 nodeImpurity : Double ,
787797 level : Int ,
788798 metadata : DecisionTreeMetadata ,
789- splits : Array [Array [Split ]]): (Split , InformationGainStats ) = {
799+ splits : Array [Array [Split ]]): (Split , InformationGainStats , Predict ) = {
790800
791801 logDebug(" node impurity = " + nodeImpurity)
792802
803+ var predict : Option [Predict ] = None
804+
793805 // For each (feature, split), calculate the gain, and select the best (feature, split).
794- Range (0 , metadata.numFeatures).map { featureIndex =>
806+ val (bestSplit, bestSplitStats) = Range (0 , metadata.numFeatures).map { featureIndex =>
795807 val numSplits = metadata.numSplits(featureIndex)
796808 if (metadata.isContinuous(featureIndex)) {
797809 // Cumulative sum (scanLeft) of bin statistics.
@@ -809,6 +821,7 @@ object DecisionTree extends Serializable with Logging {
809821 val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
810822 val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
811823 rightChildStats.subtract(leftChildStats)
824+ predict = Some (predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
812825 val gainStats =
813826 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
814827 (splitIdx, gainStats)
@@ -825,6 +838,7 @@ object DecisionTree extends Serializable with Logging {
825838 Range (0 , numSplits).map { splitIndex =>
826839 val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
827840 val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
841+ predict = Some (predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
828842 val gainStats =
829843 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
830844 (splitIndex, gainStats)
@@ -899,6 +913,7 @@ object DecisionTree extends Serializable with Logging {
899913 val rightChildStats =
900914 binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
901915 rightChildStats.subtract(leftChildStats)
916+ predict = Some (predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
902917 val gainStats =
903918 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
904919 (splitIndex, gainStats)
@@ -913,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
913928 (bestFeatureSplit, bestFeatureGainStats)
914929 }
915930 }.maxBy(_._2.gain)
931+
932+ require(predict.isDefined, " must calculate predict for each node" )
933+
934+ (bestSplit, bestSplitStats, predict.get)
916935 }
917936
918937 /**
0 commit comments