1717
1818package org .apache .spark .mllib .tree
1919
20- import scala .util .control .Breaks ._
21-
2220import org .apache .spark .annotation .Experimental
2321import org .apache .spark .{Logging , SparkContext }
2422import org .apache .spark .SparkContext ._
@@ -82,31 +80,34 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
8280 * still survived the filters of the parent nodes.
8381 */
8482
85- // TODO: Convert for loop to while loop
86- breakable {
87- for (level <- 0 until maxDepth) {
88-
89- logDebug(" #####################################" )
90- logDebug(" level = " + level)
91- logDebug(" #####################################" )
92-
93- // Find best split for all nodes at a level.
94- val splitsStatsForLevel = DecisionTree .findBestSplits(input, parentImpurities, strategy,
95- level, filters, splits, bins)
96-
97- for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
98- // Extract info for nodes at the current level.
99- extractNodeInfo(nodeSplitStats, level, index, nodes)
100- // Extract info for nodes at the next lower level.
101- extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
102- filters)
103- logDebug(" final best split = " + nodeSplitStats._1)
104- }
105- require(scala.math.pow(2 , level) == splitsStatsForLevel.length)
106- // Check whether all the nodes at the current level at leaves.
107- val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
108- logDebug(" all leaf = " + allLeaf)
109- if (allLeaf) break // no more tree construction
83+ var level = 0
84+ var break = false
85+ while (level < maxDepth && ! break) {
86+
87+ logDebug(" #####################################" )
88+ logDebug(" level = " + level)
89+ logDebug(" #####################################" )
90+
91+ // Find best split for all nodes at a level.
92+ val splitsStatsForLevel = DecisionTree .findBestSplits(input, parentImpurities, strategy,
93+ level, filters, splits, bins)
94+
95+ for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
96+ // Extract info for nodes at the current level.
97+ extractNodeInfo(nodeSplitStats, level, index, nodes)
98+ // Extract info for nodes at the next lower level.
99+ extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
100+ filters)
101+ logDebug(" final best split = " + nodeSplitStats._1)
102+ }
103+ require(scala.math.pow(2 , level) == splitsStatsForLevel.length)
104+ // Check whether all the nodes at the current level at leaves.
105+ val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
106+ logDebug(" all leaf = " + allLeaf)
107+ if (allLeaf) {
108+ break = true // no more tree construction
109+ } else {
110+ level += 1
110111 }
111112 }
112113
@@ -146,8 +147,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
146147 parentImpurities : Array [Double ],
147148 filters : Array [List [Filter ]]): Unit = {
148149 // 0 corresponds to the left child node and 1 corresponds to the right child node.
149- // TODO: Convert to while loop
150- for (i <- 0 to 1 ) {
150+ var i = 0
151+ while (i <= 1 ) {
151152 // Calculate the index of the node from the node level and the index at the current level.
152153 val nodeIndex = scala.math.pow(2 , level + 1 ).toInt - 1 + 2 * index + i
153154 if (level < maxDepth - 1 ) {
@@ -166,6 +167,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
166167 logDebug(" Filter = " + filter)
167168 }
168169 }
170+ i += 1
169171 }
170172 }
171173}
0 commit comments