@@ -367,18 +367,18 @@ object DecisionTree extends Serializable with Logging {
367367
368368 def calculateGainForSplit (leftNodeAgg : Array [Array [Double ]],
369369 featureIndex : Int ,
370- index : Int ,
370+ splitIndex : Int ,
371371 rightNodeAgg : Array [Array [Double ]],
372372 topImpurity : Double ) : InformationGainStats = {
373373 strategy.algo match {
374374 case Classification => {
375375
376- val left0Count = leftNodeAgg(featureIndex)(2 * index )
377- val left1Count = leftNodeAgg(featureIndex)(2 * index + 1 )
376+ val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex )
377+ val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
378378 val leftCount = left0Count + left1Count
379379
380- val right0Count = rightNodeAgg(featureIndex)(2 * index )
381- val right1Count = rightNodeAgg(featureIndex)(2 * index + 1 )
380+ val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex )
381+ val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1 )
382382 val rightCount = right0Count + right1Count
383383
384384 val impurity = if (level > 0 ) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
@@ -405,13 +405,13 @@ object DecisionTree extends Serializable with Logging {
405405 new InformationGainStats (gain,impurity,leftImpurity,rightImpurity,predict)
406406 }
407407 case Regression => {
408- val leftCount = leftNodeAgg(featureIndex)(3 * index )
409- val leftSum = leftNodeAgg(featureIndex)(3 * index + 1 )
410- val leftSumSquares = leftNodeAgg(featureIndex)(3 * index + 2 )
408+ val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex )
409+ val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1 )
410+ val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2 )
411411
412- val rightCount = rightNodeAgg(featureIndex)(3 * index )
413- val rightSum = rightNodeAgg(featureIndex)(3 * index + 1 )
414- val rightSumSquares = rightNodeAgg(featureIndex)(3 * index + 2 )
412+ val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex )
413+ val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1 )
414+ val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2 )
415415
416416 val impurity = if (level > 0 ) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares)
417417
@@ -463,9 +463,9 @@ object DecisionTree extends Serializable with Logging {
463463 leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
464464 = binData(shift + 2 * splitIndex + 1 ) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1 )
465465 rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
466- = binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
466+ = binData(shift + (2 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
467467 rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1 )
468- = binData(shift + (2 * (numBins - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1 )
468+ = binData(shift + (2 * (numBins - 2 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1 )
469469 }
470470 }
471471 (leftNodeAgg, rightNodeAgg)
@@ -490,11 +490,11 @@ object DecisionTree extends Serializable with Logging {
490490 leftNodeAgg(featureIndex)(3 * splitIndex + 2 )
491491 = binData(shift + 3 * splitIndex + 2 ) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2 )
492492 rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
493- = binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
493+ = binData(shift + (3 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
494494 rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1 )
495- = binData(shift + (3 * (numBins - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1 )
495+ = binData(shift + (3 * (numBins - 2 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1 )
496496 rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2 )
497- = binData(shift + (3 * (numBins - 1 - splitIndex) + 2 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2 )
497+ = binData(shift + (3 * (numBins - 2 - splitIndex) + 2 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2 )
498498 }
499499 }
500500 (leftNodeAgg, rightNodeAgg)
@@ -508,9 +508,9 @@ object DecisionTree extends Serializable with Logging {
508508 val gains = Array .ofDim[InformationGainStats ](numFeatures, numBins - 1 )
509509
510510 for (featureIndex <- 0 until numFeatures) {
511- for (index <- 0 until numBins - 1 ) {
511+ for (splitIndex <- 0 until numBins - 1 ) {
512512 // logDebug("splitIndex = " + splitIndex)
513- gains(featureIndex)(index ) = calculateGainForSplit(leftNodeAgg, featureIndex, index , rightNodeAgg, nodeImpurity)
513+ gains(featureIndex)(splitIndex ) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex , rightNodeAgg, nodeImpurity)
514514 }
515515 }
516516 gains
@@ -544,6 +544,8 @@ object DecisionTree extends Serializable with Logging {
544544 (bestFeatureIndex,bestSplitIndex,bestGainStats)
545545 }
546546
547+ logDebug(" best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
548+ logDebug(" best split bin = " + splits(bestFeatureIndex)(bestSplitIndex))
547549 (splits(bestFeatureIndex)(bestSplitIndex),gainStats)
548550 }
549551
@@ -614,13 +616,14 @@ object DecisionTree extends Serializable with Logging {
614616
615617 // Find all splits
616618 for (featureIndex <- 0 until numFeatures){
617- val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
618- if (isFeatureContinous ) {
619+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
620+ if (isFeatureContinuous ) {
619621 val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
620622
621623 val stride : Double = numSamples.toDouble/ numBins
622624 logDebug(" stride = " + stride)
623625 for (index <- 0 until numBins- 1 ) {
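626+ // TODO: Investigate this: `stride.toInt` is evaluated before the multiplication, so sampleIndex is (index + 1) * floor(stride) rather than floor((index + 1) * stride)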
624627 val sampleIndex = (index+ 1 )* stride.toInt
625628 val split = new Split (featureIndex,featureSamples(sampleIndex),Continuous , List ())
626629 splits(featureIndex)(index) = split
0 commit comments