@@ -681,36 +681,47 @@ object DecisionTree extends Serializable with Logging {
681681 topImpurity : Double ): InformationGainStats = {
682682 strategy.algo match {
683683 case Classification =>
684- // TODO: Modify here
685- val left0Count = leftNodeAgg(featureIndex)(splitIndex)(0 )
686- val left1Count = leftNodeAgg(featureIndex)(splitIndex)(1 )
687- val leftCount = left0Count + left1Count
688-
689- val right0Count = rightNodeAgg(featureIndex)(splitIndex)(0 )
690- val right1Count = rightNodeAgg(featureIndex)(splitIndex)(1 )
691- val rightCount = right0Count + right1Count
684+ var classIndex = 0
685+ val leftCounts : Array [Double ] = new Array [Double ](numClasses)
686+ val rightCounts : Array [Double ] = new Array [Double ](numClasses)
687+ var leftTotalCount = 0.0
688+ var rightTotalCount = 0.0
689+ while (classIndex < numClasses) {
690+ val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex)
691+ val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex)
692+ leftCounts(classIndex) = leftClassCount
693+ leftTotalCount += leftClassCount
694+ rightCounts(classIndex) = rightClassCount
695+ rightTotalCount += rightClassCount
696+ classIndex += 1
697+ }
692698
// Impurity of the node being split. Below the root (level > 0) the topImpurity
// parameter is reused; at the root it is computed from the combined left and
// right per-class counts.
val impurity = {
  if (level > 0) {
    topImpurity
  } else {
    // Calculate impurity for root node.
    val rootNodeCounts = new Array[Double](numClasses)
    // Named distinctly so it does not shadow the outer classIndex counter.
    var rootClassIndex = 0
    while (rootClassIndex < numClasses) {
      rootNodeCounts(rootClassIndex) = leftCounts(rootClassIndex) + rightCounts(rootClassIndex)
      // Advance the index; without this increment the loop never terminates.
      rootClassIndex += 1
    }
    strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
  }
}
701712
702- if (leftCount == 0 ) {
713+ if (leftTotalCount == 0 ) {
703714 return new InformationGainStats (0 , topImpurity, Double .MinValue , topImpurity,1 )
704715 }
705- if (rightCount == 0 ) {
716+ if (rightTotalCount == 0 ) {
706717 return new InformationGainStats (0 , topImpurity, topImpurity, Double .MinValue ,0 )
707718 }
708719
709- val leftImpurity = strategy.impurity.calculate(left0Count, left1Count )
710- val rightImpurity = strategy.impurity.calculate(right0Count, right1Count )
720+ val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount )
721+ val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount )
711722
712- val leftWeight = leftCount .toDouble / (leftCount + rightCount )
713- val rightWeight = rightCount .toDouble / (leftCount + rightCount )
723+ val leftWeight = leftTotalCount .toDouble / (leftTotalCount + rightTotalCount )
724+ val rightWeight = rightTotalCount .toDouble / (leftTotalCount + rightTotalCount )
714725
715726 val gain = {
716727 if (level > 0 ) {
@@ -720,7 +731,8 @@ object DecisionTree extends Serializable with Logging {
720731 }
721732 }
722733
723- val predict = (left1Count + right1Count) / (leftCount + rightCount)
734+ // TODO: Extend to multiclass; predict currently uses only the class-1 fraction of all samples.
735+ val predict = (leftCounts(1 ) + rightCounts(1 )) / (leftTotalCount + rightTotalCount)
724736
725737 new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict)
726738 case Regression =>
@@ -782,7 +794,6 @@ object DecisionTree extends Serializable with Logging {
782794 binData : Array [Double ]): (Array [Array [Array [Double ]]], Array [Array [Array [Double ]]]) = {
783795 strategy.algo match {
784796 case Classification =>
785- // TODO: Multiclass modification here
786797
787798 // Initialize left and right split aggregates.
788799 val leftNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , numClasses)
@@ -793,17 +804,19 @@ object DecisionTree extends Serializable with Logging {
793804 while (featureIndex < numFeatures){
794805 val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
795806 val maxSplits = math.pow(2 , numCategories) - 1
796- var i = 0
797- // TODO: Add multiclass case here
798- while (i < maxSplits) {
807+ var splitIndex = 0
808+ while (splitIndex < maxSplits) {
799809 var classIndex = 0
800810 while (classIndex < numClasses) {
801811 // shift for this featureIndex
802812 val shift = numClasses * featureIndex * numBins
803-
813+ leftNodeAgg(featureIndex)(splitIndex)(classIndex)
814+ = binData(shift + classIndex)
815+ rightNodeAgg(featureIndex)(splitIndex)(classIndex)
816+ = binData(shift + numClasses + classIndex)
804817 classIndex += 1
805818 }
806- i += 1
819+ splitIndex += 1
807820 }
808821 featureIndex += 1
809822 }
@@ -931,8 +944,6 @@ object DecisionTree extends Serializable with Logging {
931944 binData : Array [Double ],
932945 nodeImpurity : Double ): (Split , InformationGainStats ) = {
933946
934- // TODO: Multiclass modification here
935-
936947 logDebug(" node impurity = " + nodeImpurity)
937948
938949 // Extract left right node aggregates.
@@ -977,9 +988,8 @@ object DecisionTree extends Serializable with Logging {
977988 def getBinDataForNode (node : Int ): Array [Double ] = {
978989 strategy.algo match {
979990 case Classification =>
980- // TODO: Multiclass modification here
981- val shift = 2 * node * numBins * numFeatures
982- val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
991+ val shift = numClasses * node * numBins * numFeatures
992+ val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
983993 binsForNode
984994 case Regression =>
985995 val shift = 3 * node * numBins * numFeatures
0 commit comments