Skip to content

Commit f067d68

Browse files
committed
minor cleanup
Signed-off-by: Manish Amde <[email protected]>
1 parent c0e522b commit f067d68

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
4141
logDebug("numSplits = " + bins(0).length)
4242
strategy.numBins = bins(0).length
4343

44-
//TODO: Level-wise training of tree and obtain Decision Tree model
4544
val maxDepth = strategy.maxDepth
4645

4746
val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1
@@ -62,7 +61,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
6261
logDebug("#####################################")
6362

6463
//Find best split for all nodes at a level
65-
val numNodes= scala.math.pow(2,level).toInt
6664
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins)
6765

6866
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){
@@ -105,7 +103,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
105103
private def extractInfoForLowerLevels(level: Int, index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], filters: Array[List[Filter]]) {
106104
for (i <- 0 to 1) {
107105

108-
val nodeIndex = (scala.math.pow(2, level + 1)).toInt - 1 + 2 * index + i
106+
val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
109107

110108
if (level < maxDepth - 1) {
111109

@@ -205,7 +203,6 @@ object DecisionTree extends Serializable with Logging {
205203
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = {
206204

207205
if (isFeatureContinuous){
208-
//TODO: Do binary search
209206
for (binIndex <- 0 until strategy.numBins) {
210207
val bin = bins(featureIndex)(binIndex)
211208
val lowThreshold = bin.lowSplit.threshold
@@ -250,9 +247,12 @@ object DecisionTree extends Serializable with Logging {
250247
val shift = 1 + numFeatures * nodeIndex
251248
if (!sampleValid) {
252249
//Add to invalid bin index -1
253-
for (featureIndex <- 0 until numFeatures) {
254-
arr(shift+featureIndex) = -1
255-
//TODO: Break since marking one bin is sufficient
250+
breakable {
251+
for (featureIndex <- 0 until numFeatures) {
252+
arr(shift+featureIndex) = -1
253+
//Breaking since marking one bin is sufficient
254+
break()
255+
}
256256
}
257257
} else {
258258
for (featureIndex <- 0 until numFeatures) {
@@ -318,7 +318,6 @@ object DecisionTree extends Serializable with Logging {
318318
def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = {
319319
strategy.algo match {
320320
case Classification => classificationBinSeqOp(arr, agg)
321-
//TODO: Implement this
322321
case Regression => regressionBinSeqOp(arr, agg)
323322
}
324323
agg
@@ -599,7 +598,6 @@ object DecisionTree extends Serializable with Logging {
599598

600599
logDebug("maxBins = " + numBins)
601600
//Calculate the number of sample for approximate quantile calculation
602-
//TODO: Justify this calculation
603601
val requiredSamples = numBins*numBins
604602
val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
605603
logDebug("fraction of data used for calculating quantiles = " + fraction)
@@ -624,7 +622,6 @@ object DecisionTree extends Serializable with Logging {
624622
val stride : Double = numSamples.toDouble/numBins
625623
logDebug("stride = " + stride)
626624
for (index <- 0 until numBins-1) {
627-
//TODO: Investigate this
628625
val sampleIndex = (index+1)*stride.toInt
629626
val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List())
630627
splits(featureIndex)(index) = split

0 commit comments

Comments
 (0)