|
17 | 17 |
|
18 | 18 | package org.apache.spark.mllib.tree |
19 | 19 |
|
20 | | -import scala.util.control.Breaks._ |
21 | | - |
22 | 20 | import org.apache.spark.annotation.Experimental |
23 | 21 | import org.apache.spark.{Logging, SparkContext} |
24 | 22 | import org.apache.spark.SparkContext._ |
@@ -82,32 +80,32 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo |
82 | 80 | * still survived the filters of the parent nodes. |
83 | 81 | */ |
84 | 82 |
|
85 | | - // TODO: Convert for loop to while loop |
86 | | - breakable { |
87 | | - for (level <- 0 until maxDepth) { |
88 | | - |
89 | | - logDebug("#####################################") |
90 | | - logDebug("level = " + level) |
91 | | - logDebug("#####################################") |
92 | | - |
93 | | - // Find best split for all nodes at a level. |
94 | | - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, |
95 | | - level, filters, splits, bins) |
96 | | - |
97 | | - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { |
98 | | - // Extract info for nodes at the current level. |
99 | | - extractNodeInfo(nodeSplitStats, level, index, nodes) |
100 | | - // Extract info for nodes at the next lower level. |
101 | | - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, |
102 | | - filters) |
103 | | - logDebug("final best split = " + nodeSplitStats._1) |
104 | | - } |
105 | | - require(scala.math.pow(2, level) == splitsStatsForLevel.length) |
106 | | - // Check whether all the nodes at the current level at leaves. |
107 | | - val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) |
108 | | - logDebug("all leaf = " + allLeaf) |
109 | | - if (allLeaf) break // no more tree construction |
| 83 | + var level = 0 |
| 84 | + var break = false |
| 85 | + while (level < maxDepth && !break) { |
| 86 | + |
| 87 | + logDebug("#####################################") |
| 88 | + logDebug("level = " + level) |
| 89 | + logDebug("#####################################") |
| 90 | + |
| 91 | + // Find best split for all nodes at a level. |
| 92 | + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, |
| 93 | + level, filters, splits, bins) |
| 94 | + |
| 95 | + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { |
| 96 | + // Extract info for nodes at the current level. |
| 97 | + extractNodeInfo(nodeSplitStats, level, index, nodes) |
| 98 | + // Extract info for nodes at the next lower level. |
| 99 | + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, |
| 100 | + filters) |
| 101 | + logDebug("final best split = " + nodeSplitStats._1) |
110 | 102 | } |
| 103 | + require(scala.math.pow(2, level) == splitsStatsForLevel.length) |
| 104 | + // Check whether all the nodes at the current level at leaves. |
| 105 | + val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) |
| 106 | + logDebug("all leaf = " + allLeaf) |
| 107 | + if (allLeaf) break = true // no more tree construction |
| 108 | + else level += 1 |
111 | 109 | } |
112 | 110 |
|
113 | 111 | // Initialize the top or root node of the tree. |
|
0 commit comments