@@ -41,7 +41,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
4141 logDebug(" numSplits = " + bins(0 ).length)
4242 strategy.numBins = bins(0 ).length
4343
44- // TODO: Level-wise training of tree and obtain Decision Tree model
4544 val maxDepth = strategy.maxDepth
4645
4746 val maxNumNodes = scala.math.pow(2 ,maxDepth).toInt - 1
@@ -62,7 +61,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
6261 logDebug(" #####################################" )
6362
6463 // Find best split for all nodes at a level
65- val numNodes = scala.math.pow(2 ,level).toInt
6664 val splitsStatsForLevel = DecisionTree .findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins)
6765
6866 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){
@@ -105,7 +103,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
105103 private def extractInfoForLowerLevels (level : Int , index : Int , maxDepth : Int , nodeSplitStats : (Split , InformationGainStats ), parentImpurities : Array [Double ], filters : Array [List [Filter ]]) {
106104 for (i <- 0 to 1 ) {
107105
108- val nodeIndex = ( scala.math.pow(2 , level + 1 ) ).toInt - 1 + 2 * index + i
106+ val nodeIndex = scala.math.pow(2 , level + 1 ).toInt - 1 + 2 * index + i
109107
110108 if (level < maxDepth - 1 ) {
111109
@@ -205,7 +203,6 @@ object DecisionTree extends Serializable with Logging {
205203 def findBin (featureIndex : Int , labeledPoint : LabeledPoint , isFeatureContinuous : Boolean ) : Int = {
206204
207205 if (isFeatureContinuous){
208- // TODO: Do binary search
209206 for (binIndex <- 0 until strategy.numBins) {
210207 val bin = bins(featureIndex)(binIndex)
211208 val lowThreshold = bin.lowSplit.threshold
@@ -250,9 +247,12 @@ object DecisionTree extends Serializable with Logging {
250247 val shift = 1 + numFeatures * nodeIndex
251248 if (! sampleValid) {
252249 // Add to invalid bin index -1
253- for (featureIndex <- 0 until numFeatures) {
254- arr(shift+ featureIndex) = - 1
255- // TODO: Break since marking one bin is sufficient
250+ breakable {
251+ for (featureIndex <- 0 until numFeatures) {
252+ arr(shift+ featureIndex) = - 1
253+ // Breaking since marking one bin is sufficient
254+ break()
255+ }
256256 }
257257 } else {
258258 for (featureIndex <- 0 until numFeatures) {
@@ -318,7 +318,6 @@ object DecisionTree extends Serializable with Logging {
318318 def binSeqOp (agg : Array [Double ], arr : Array [Double ]) : Array [Double ] = {
319319 strategy.algo match {
320320 case Classification => classificationBinSeqOp(arr, agg)
321- // TODO: Implement this
322321 case Regression => regressionBinSeqOp(arr, agg)
323322 }
324323 agg
@@ -599,7 +598,6 @@ object DecisionTree extends Serializable with Logging {
599598
600599 logDebug(" maxBins = " + numBins)
601600 // Calculate the number of sample for approximate quantile calculation
602- // TODO: Justify this calculation
603601 val requiredSamples = numBins* numBins
604602 val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
605603 logDebug(" fraction of data used for calculating quantiles = " + fraction)
@@ -624,7 +622,6 @@ object DecisionTree extends Serializable with Logging {
624622 val stride : Double = numSamples.toDouble/ numBins
625623 logDebug(" stride = " + stride)
626624 for (index <- 0 until numBins- 1 ) {
627- // TODO: Investigate this
628625 val sampleIndex = (index+ 1 )* stride.toInt
629626 val split = new Split (featureIndex,featureSamples(sampleIndex),Continuous , List ())
630627 splits(featureIndex)(index) = split
0 commit comments