@@ -89,8 +89,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
8989 s " DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth. " )
9090 // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
9191 val maxNumNodesPlus1 = Node .startIndexInLevel(maxDepth + 1 )
92- // Initialize an array to hold parent impurity calculations for each node.
93- val parentImpurities = new Array [Double ](maxNumNodesPlus1)
9492 // dummy value for top node (updated during first split calculation)
9593 val nodes = new Array [Node ](maxNumNodesPlus1)
9694
@@ -131,7 +129,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
131129 // Find best split for all nodes at a level.
132130 timer.start(" findBestSplits" )
133131 val splitsStatsForLevel : Array [(Split , InformationGainStats )] =
134- DecisionTree .findBestSplits(treeInput, parentImpurities,
132+ DecisionTree .findBestSplits(treeInput,
135133 metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
136134 timer.stop(" findBestSplits" )
137135
@@ -158,20 +156,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
158156 nodes(parentNodeIndex).rightNode = Some (nodes(nodeIndex))
159157 }
160158 }
161- // Extract info for nodes at the next lower level.
162- timer.start(" extractInfoForLowerLevels" )
163159 if (level < maxDepth) {
164- val leftChildIndex = Node .leftChildIndex(nodeIndex)
165- val leftImpurity = stats.leftImpurity
166- logDebug(" leftChildIndex = " + leftChildIndex + " , impurity = " + leftImpurity)
167- parentImpurities(leftChildIndex) = leftImpurity
168-
169- val rightChildIndex = Node .rightChildIndex(nodeIndex)
170- val rightImpurity = stats.rightImpurity
171- logDebug(" rightChildIndex = " + rightChildIndex + " , impurity = " + rightImpurity)
172- parentImpurities(rightChildIndex) = rightImpurity
160+ logDebug(" leftChildIndex = " + Node .leftChildIndex(nodeIndex) +
161+ " , impurity = " + stats.leftImpurity)
162+ logDebug(" rightChildIndex = " + Node .rightChildIndex(nodeIndex) +
163+ " , impurity = " + stats.rightImpurity)
173164 }
174- timer.stop(" extractInfoForLowerLevels" )
175165 logDebug(" final best split = " + split)
176166 }
177167 require(Node .maxNodesInLevel(level) == splitsStatsForLevel.length)
@@ -189,17 +179,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
189179 logDebug(" Extracting tree model" )
190180 logDebug(" #####################################" )
191181
192- // Initialize the top or root node of the tree.
193- val topNode = nodes(1 )
194- // Build the full tree using the node info calculated in the level-wise best split calculations.
195- topNode.build(nodes)
196-
197182 timer.stop(" total" )
198183
199184 logInfo(" Internal timing for DecisionTree:" )
200185 logInfo(s " $timer" )
201186
202- new DecisionTreeModel (topNode , strategy.algo)
187+ new DecisionTreeModel (nodes( 1 ) , strategy.algo)
203188 }
204189
205190}
@@ -408,7 +393,6 @@ object DecisionTree extends Serializable with Logging {
408393 * multiple groups if the level-wise training task could lead to memory overflow.
409394 *
410395 * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint ]]
411- * @param parentImpurities Impurities for all parent nodes for the current level
412396 * @param metadata Learning and dataset metadata
413397 * @param level Level of the tree
414398 * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
@@ -418,7 +402,6 @@ object DecisionTree extends Serializable with Logging {
418402 */
419403 private [tree] def findBestSplits (
420404 input : RDD [TreePoint ],
421- parentImpurities : Array [Double ],
422405 metadata : DecisionTreeMetadata ,
423406 level : Int ,
424407 nodes : Array [Node ],
@@ -438,14 +421,14 @@ object DecisionTree extends Serializable with Logging {
438421 // Iterate over each group of nodes at a level.
439422 var groupIndex = 0
440423 while (groupIndex < numGroups) {
441- val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level,
424+ val bestSplitsForGroup = findBestSplitsPerGroup(input, metadata, level,
442425 nodes, splits, bins, timer, numGroups, groupIndex)
443426 bestSplits = Array .concat(bestSplits, bestSplitsForGroup)
444427 groupIndex += 1
445428 }
446429 bestSplits
447430 } else {
448- findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer)
431+ findBestSplitsPerGroup(input, metadata, level, nodes, splits, bins, timer)
449432 }
450433 }
451434
@@ -585,7 +568,6 @@ object DecisionTree extends Serializable with Logging {
585568 * Returns an array of optimal splits for a group of nodes at a given level
586569 *
587570 * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint ]]
588- * @param parentImpurities Impurities for all parent nodes for the current level
589571 * @param metadata Learning and dataset metadata
590572 * @param level Level of the tree
591573 * @param nodes Array of all nodes in the tree. Used for matching data points to nodes.
@@ -597,7 +579,6 @@ object DecisionTree extends Serializable with Logging {
597579 */
598580 private def findBestSplitsPerGroup (
599581 input : RDD [TreePoint ],
600- parentImpurities : Array [Double ],
601582 metadata : DecisionTreeMetadata ,
602583 level : Int ,
603584 nodes : Array [Node ],
@@ -709,10 +690,8 @@ object DecisionTree extends Serializable with Logging {
709690 // Iterating over all nodes at this level
710691 var nodeIndex = 0
711692 while (nodeIndex < numNodes) {
712- val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex)
713- logDebug(" node impurity = " + nodeImpurity)
714693 bestSplits(nodeIndex) =
715- binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits)
694+ binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
716695 logDebug(" best split = " + bestSplits(nodeIndex)._1)
717696 nodeIndex += 1
718697 }
@@ -725,13 +704,11 @@ object DecisionTree extends Serializable with Logging {
725704 * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
726705 * @param leftImpurityCalculator left node aggregates for this (feature, split)
727706 * @param rightImpurityCalculator right node aggregate for this (feature, split)
728- * @param topImpurity impurity of the parent node
729707 * @return information gain and statistics for all splits
730708 */
731709 private def calculateGainForSplit (
732710 leftImpurityCalculator : ImpurityCalculator ,
733711 rightImpurityCalculator : ImpurityCalculator ,
734- topImpurity : Double ,
735712 level : Int ,
736713 metadata : DecisionTreeMetadata ): InformationGainStats = {
737714
@@ -741,18 +718,13 @@ object DecisionTree extends Serializable with Logging {
741718 val totalCount = leftCount + rightCount
742719 if (totalCount == 0 ) {
743720 // Return arbitrary prediction.
744- return new InformationGainStats (0 , topImpurity, topImpurity, topImpurity , 0 )
721+ return new InformationGainStats (0 , 0 , 0 , 0 , 0 )
745722 }
746723
747724 val parentNodeAgg = leftImpurityCalculator.copy
748725 parentNodeAgg.add(rightImpurityCalculator)
749- // impurity of parent node
750- val impurity = if (level > 0 ) {
751- topImpurity
752- } else {
753- parentNodeAgg.calculate()
754- }
755726
727+ val impurity = parentNodeAgg.calculate()
756728 val predict = parentNodeAgg.predict
757729 val prob = parentNodeAgg.prob(predict)
758730
@@ -771,19 +743,15 @@ object DecisionTree extends Serializable with Logging {
771743 * Find the best split for a node.
772744 * @param binAggregates Bin statistics.
773745 * @param nodeIndex Index for node to split in this (level, group).
774- * @param nodeImpurity Impurity of the node (nodeIndex).
775746 * @return tuple for best split: (Split, information gain)
776747 */
777748 private def binsToBestSplit (
778749 binAggregates : DTStatsAggregator ,
779750 nodeIndex : Int ,
780- nodeImpurity : Double ,
781751 level : Int ,
782752 metadata : DecisionTreeMetadata ,
783753 splits : Array [Array [Split ]]): (Split , InformationGainStats ) = {
784754
785- logDebug(" node impurity = " + nodeImpurity)
786-
787755 // For each (feature, split), calculate the gain, and select the best (feature, split).
788756 Range (0 , metadata.numFeatures).map { featureIndex =>
789757 val numSplits = metadata.numSplits(featureIndex)
@@ -803,8 +771,7 @@ object DecisionTree extends Serializable with Logging {
803771 val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
804772 val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
805773 rightChildStats.subtract(leftChildStats)
806- val gainStats =
807- calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
774+ val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
808775 (splitIdx, gainStats)
809776 }.maxBy(_._2.gain)
810777 (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -816,8 +783,7 @@ object DecisionTree extends Serializable with Logging {
816783 Range (0 , numSplits).map { splitIndex =>
817784 val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
818785 val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
819- val gainStats =
820- calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
786+ val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
821787 (splitIndex, gainStats)
822788 }.maxBy(_._2.gain)
823789 (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -887,8 +853,7 @@ object DecisionTree extends Serializable with Logging {
887853 val rightChildStats =
888854 binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
889855 rightChildStats.subtract(leftChildStats)
890- val gainStats =
891- calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
856+ val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
892857 (splitIndex, gainStats)
893858 }.maxBy(_._2.gain)
894859 val categoriesForSplit =
0 commit comments