@@ -385,6 +385,8 @@ object DecisionTree extends Serializable with Logging {
385385 logDebug(" numFeatures = " + numFeatures)
386386 val numBins = bins(0 ).length
387387 logDebug(" numBins = " + numBins)
388+ val numClasses = strategy.numClassesForClassification
389+ logDebug(" numClasses = " + numClasses)
388390
389391 // shift when more than one group is used at deep tree level
390392 val groupShift = numNodes * groupIndex
@@ -545,10 +547,10 @@ object DecisionTree extends Serializable with Logging {
545547 * incremented based upon whether the feature is classified as 0 or 1.
546548 *
547549 * @param agg Array[Double] storing aggregate calculation of size
548- * 2 * numSplits * numFeatures*numNodes for classification
550+ * numClasses * numBins * numFeatures * numNodes for classification
549551 * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
550552 * @return Array[Double] storing aggregate calculation of size
551- * 2 * numSplits * numFeatures * numNodes for classification
553+ * numClasses * numBins * numFeatures * numNodes for classification
552554 */
553555 def classificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) {
554556 // Iterate over all nodes.
@@ -562,16 +564,16 @@ object DecisionTree extends Serializable with Logging {
562564 val label = arr(0 )
563565 // Iterate over all features.
564566 var featureIndex = 0
565- // TODO: Multiclass modification here
566567 while (featureIndex < numFeatures) {
567568 // Find the bin index for this feature.
568569 val arrShift = 1 + numFeatures * nodeIndex
569570 val arrIndex = arrShift + featureIndex
570571 // Update the count of this example's label class for one bin.
571- val aggShift = 2 * numBins * numFeatures * nodeIndex
572- val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
573- label match {
574- case n : Double => agg(aggIndex) = agg(aggIndex + n.toInt) + 1
572+ val aggShift = numClasses * numBins * numFeatures * nodeIndex
573+ val aggIndex
574+ = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
575+ label.toInt match {
576+ case n : Int => agg(aggIndex + n) = agg(aggIndex + n) + 1
575577 }
576578 featureIndex += 1
577579 }
@@ -632,7 +634,7 @@ object DecisionTree extends Serializable with Logging {
632634
633635 // Calculate bin aggregate length for classification or regression.
634636 val binAggregateLength = strategy.algo match {
635- case Classification => 2 * numBins * numFeatures * numNodes
637+ case Classification => numClasses * numBins * numFeatures * numNodes
636638 case Regression => 3 * numBins * numFeatures * numNodes
637639 }
638640 logDebug(" binAggregateLength = " + binAggregateLength)
@@ -672,20 +674,20 @@ object DecisionTree extends Serializable with Logging {
672674 * @return information gain and statistics for all splits
673675 */
674676 def calculateGainForSplit (
675- leftNodeAgg : Array [Array [Double ]],
677+ leftNodeAgg : Array [Array [Array [ Double ] ]],
676678 featureIndex : Int ,
677679 splitIndex : Int ,
678- rightNodeAgg : Array [Array [Double ]],
680+ rightNodeAgg : Array [Array [Array [ Double ] ]],
679681 topImpurity : Double ): InformationGainStats = {
680682 strategy.algo match {
681683 case Classification =>
682684 // TODO: Generalize to numClasses; only the counts at class indices 0 and 1 are read here.
683- val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
684- val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
685+ val left0Count = leftNodeAgg(featureIndex)(splitIndex)( 0 )
686+ val left1Count = leftNodeAgg(featureIndex)(splitIndex)( 1 )
685687 val leftCount = left0Count + left1Count
686688
687- val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
688- val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1 )
689+ val right0Count = rightNodeAgg(featureIndex)(splitIndex)( 0 )
690+ val right1Count = rightNodeAgg(featureIndex)(splitIndex)( 1 )
689691 val rightCount = right0Count + right1Count
690692
691693 val impurity = {
@@ -722,13 +724,13 @@ object DecisionTree extends Serializable with Logging {
722724
723725 new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict)
724726 case Regression =>
725- val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
726- val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1 )
727- val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2 )
727+ val leftCount = leftNodeAgg(featureIndex)(splitIndex)( 0 )
728+ val leftSum = leftNodeAgg(featureIndex)(splitIndex)( 1 )
729+ val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)( 2 )
728730
729- val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
730- val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1 )
731- val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2 )
731+ val rightCount = rightNodeAgg(featureIndex)(splitIndex)( 0 )
732+ val rightSum = rightNodeAgg(featureIndex)(splitIndex)( 1 )
733+ val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)( 2 )
732734
733735 val impurity = {
734736 if (level > 0 ) {
@@ -777,98 +779,121 @@ object DecisionTree extends Serializable with Logging {
777779 * Array[Double]) where each array is of size (numFeatures, numBins - 1, numClasses), or (numFeatures, numBins - 1, 3) for regression
778780 */
779781 def extractLeftRightNodeAggregates (
780- binData : Array [Double ]): (Array [Array [Double ]], Array [Array [Double ]]) = {
782+ binData : Array [Double ]): (Array [Array [Array [ Double ]]] , Array [Array [Array [ Double ] ]]) = {
781783 strategy.algo match {
782784 case Classification =>
783785 // TODO: Multiclass modification here
784- // Initialize left and right split aggregates.
785- val leftNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numBins - 1 ))
786- val rightNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numBins - 1 ))
787- // Iterate over all features.
788- var featureIndex = 0
789- while (featureIndex < numFeatures) {
790- // shift for this featureIndex
791- val shift = 2 * featureIndex * numBins
792-
793- // left node aggregate for the lowest split
794- leftNodeAgg(featureIndex)(0 ) = binData(shift + 0 )
795- leftNodeAgg(featureIndex)(1 ) = binData(shift + 1 )
796-
797- // right node aggregate for the highest split
798- rightNodeAgg(featureIndex)(2 * (numBins - 2 ))
799- = binData(shift + (2 * (numBins - 1 )))
800- rightNodeAgg(featureIndex)(2 * (numBins - 2 ) + 1 )
801- = binData(shift + (2 * (numBins - 1 )) + 1 )
802-
803- // Iterate over all splits.
804- var splitIndex = 1
805- while (splitIndex < numBins - 1 ) {
806- // calculating left node aggregate for a split as a sum of left node aggregate of a
807- // lower split and the left bin aggregate of a bin where the split is a high split
808- leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) +
809- leftNodeAgg(featureIndex)(2 * splitIndex - 2 )
810- leftNodeAgg(featureIndex)(2 * splitIndex + 1 ) = binData(shift + 2 * splitIndex + 1 ) +
811- leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1 )
812786
813- // calculating right node aggregate for a split as a sum of right node aggregate of a
814- // higher split and the right bin aggregate of a bin where the split is a low split
815- rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
816- binData(shift + (2 * (numBins - 2 - splitIndex))) +
817- rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
818- rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1 ) =
819- binData(shift + (2 * (numBins - 2 - splitIndex) + 1 )) +
820- rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1 )
821-
822- splitIndex += 1
787+ // Initialize left and right split aggregates.
788+ val leftNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , numClasses)
789+ val rightNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , numClasses)
790+
791+ if (strategy.isMultiClassification) {
792+ var featureIndex = 0
793+ while (featureIndex < numFeatures){
794+ val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
795+ val maxSplits = math.pow(2 , numCategories) - 1
796+ var i = 0
797+ // TODO: Fill leftNodeAgg/rightNodeAgg for the multiclass case; the loops below do not write them yet.
798+ while (i < maxSplits) {
799+ var classIndex = 0
800+ while (classIndex < numClasses) {
801+ // shift for this featureIndex
802+ val shift = numClasses * featureIndex * numBins
803+
804+ classIndex += 1
805+ }
806+ i += 1
807+ }
808+ featureIndex += 1
809+ }
810+ } else {
811+ // Iterate over all features.
812+ var featureIndex = 0
813+ while (featureIndex < numFeatures) {
814+ // shift for this featureIndex
815+ val shift = 2 * featureIndex * numBins
816+
817+ // left node aggregate for the lowest split
818+ leftNodeAgg(featureIndex)(0 )(0 ) = binData(shift + 0 )
819+ leftNodeAgg(featureIndex)(0 )(1 ) = binData(shift + 1 )
820+
821+ // right node aggregate for the highest split
822+ rightNodeAgg(featureIndex)(numBins - 2 )(0 )
823+ = binData(shift + (2 * (numBins - 1 )))
824+ rightNodeAgg(featureIndex)(numBins - 2 )(1 )
825+ = binData(shift + (2 * (numBins - 1 )) + 1 )
826+
827+ // Iterate over all splits.
828+ var splitIndex = 1
829+ while (splitIndex < numBins - 1 ) {
830+ // calculating left node aggregate for a split as a sum of left node aggregate of a
831+ // lower split and the left bin aggregate of a bin where the split is a high split
832+ leftNodeAgg(featureIndex)(splitIndex)(0 ) = binData(shift + 2 * splitIndex) +
833+ leftNodeAgg(featureIndex)(splitIndex - 1 )(0 )
834+ leftNodeAgg(featureIndex)(splitIndex)(1 ) = binData(shift + 2 * splitIndex +
835+ 1 ) + leftNodeAgg(featureIndex)(splitIndex - 1 )(1 )
836+
837+ // calculating right node aggregate for a split as a sum of right node aggregate of a
838+ // higher split and the right bin aggregate of a bin where the split is a low split
839+ rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0 ) =
840+ binData(shift + (2 * (numBins - 2 - splitIndex))) +
841+ rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0 )
842+ rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1 ) =
843+ binData(shift + (2 * (numBins - 2 - splitIndex) + 1 )) +
844+ rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1 )
845+
846+ splitIndex += 1
847+ }
848+ featureIndex += 1
823849 }
824- featureIndex += 1
825850 }
826851 (leftNodeAgg, rightNodeAgg)
827852 case Regression =>
828853 // Initialize left and right split aggregates.
829- val leftNodeAgg = Array .ofDim[Double ](numFeatures, 3 * ( numBins - 1 ) )
830- val rightNodeAgg = Array .ofDim[Double ](numFeatures, 3 * ( numBins - 1 ) )
854+ val leftNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , 3 )
855+ val rightNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , 3 )
831856 // Iterate over all features.
832857 var featureIndex = 0
833858 while (featureIndex < numFeatures) {
834859 // shift for this featureIndex
835860 val shift = 3 * featureIndex * numBins
836861 // left node aggregate for the lowest split
837- leftNodeAgg(featureIndex)(0 ) = binData(shift + 0 )
838- leftNodeAgg(featureIndex)(1 ) = binData(shift + 1 )
839- leftNodeAgg(featureIndex)(2 ) = binData(shift + 2 )
862+ leftNodeAgg(featureIndex)(0 )( 0 ) = binData(shift + 0 )
863+ leftNodeAgg(featureIndex)(0 )( 1 ) = binData(shift + 1 )
864+ leftNodeAgg(featureIndex)(0 )( 2 ) = binData(shift + 2 )
840865
841866 // right node aggregate for the highest split
842- rightNodeAgg(featureIndex)(3 * ( numBins - 2 )) =
867+ rightNodeAgg(featureIndex)(numBins - 2 )( 0 ) =
843868 binData(shift + (3 * (numBins - 1 )))
844- rightNodeAgg(featureIndex)(3 * ( numBins - 2 ) + 1 ) =
869+ rightNodeAgg(featureIndex)(numBins - 2 )( 1 ) =
845870 binData(shift + (3 * (numBins - 1 )) + 1 )
846- rightNodeAgg(featureIndex)(3 * ( numBins - 2 ) + 2 ) =
871+ rightNodeAgg(featureIndex)(numBins - 2 )( 2 ) =
847872 binData(shift + (3 * (numBins - 1 )) + 2 )
848873
849874 // Iterate over all splits.
850875 var splitIndex = 1
851876 while (splitIndex < numBins - 1 ) {
852877 // calculating left node aggregate for a split as a sum of left node aggregate of a
853878 // lower split and the left bin aggregate of a bin where the split is a high split
854- leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) +
855- leftNodeAgg(featureIndex)(3 * splitIndex - 3 )
856- leftNodeAgg(featureIndex)(3 * splitIndex + 1 ) = binData(shift + 3 * splitIndex + 1 ) +
857- leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1 )
858- leftNodeAgg(featureIndex)(3 * splitIndex + 2 ) = binData(shift + 3 * splitIndex + 2 ) +
859- leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2 )
879+ leftNodeAgg(featureIndex)(splitIndex)( 0 ) = binData(shift + 3 * splitIndex) +
880+ leftNodeAgg(featureIndex)(splitIndex - 1 )( 0 )
881+ leftNodeAgg(featureIndex)(splitIndex)( 1 ) = binData(shift + 3 * splitIndex + 1 ) +
882+ leftNodeAgg(featureIndex)(splitIndex - 1 )( 1 )
883+ leftNodeAgg(featureIndex)(splitIndex)( 2 ) = binData(shift + 3 * splitIndex + 2 ) +
884+ leftNodeAgg(featureIndex)(splitIndex - 1 )( 2 )
860885
861886 // calculating right node aggregate for a split as a sum of right node aggregate of a
862887 // higher split and the right bin aggregate of a bin where the split is a low split
863- rightNodeAgg(featureIndex)(3 * ( numBins - 2 - splitIndex)) =
888+ rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)( 0 ) =
864889 binData(shift + (3 * (numBins - 2 - splitIndex))) +
865- rightNodeAgg(featureIndex)(3 * ( numBins - 1 - splitIndex))
866- rightNodeAgg(featureIndex)(3 * ( numBins - 2 - splitIndex) + 1 ) =
890+ rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)( 0 )
891+ rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)( 1 ) =
867892 binData(shift + (3 * (numBins - 2 - splitIndex) + 1 )) +
868- rightNodeAgg(featureIndex)(3 * ( numBins - 1 - splitIndex) + 1 )
869- rightNodeAgg(featureIndex)(3 * ( numBins - 2 - splitIndex) + 2 ) =
893+ rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)( 1 )
894+ rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)( 2 ) =
870895 binData(shift + (3 * (numBins - 2 - splitIndex) + 2 )) +
871- rightNodeAgg(featureIndex)(3 * ( numBins - 1 - splitIndex) + 2 )
896+ rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)( 2 )
872897
873898 splitIndex += 1
874899 }
@@ -882,8 +907,8 @@ object DecisionTree extends Serializable with Logging {
882907 * Calculates information gain for all nodes splits.
883908 */
884909 def calculateGainsForAllNodeSplits (
885- leftNodeAgg : Array [Array [Double ]],
886- rightNodeAgg : Array [Array [Double ]],
910+ leftNodeAgg : Array [Array [Array [ Double ] ]],
911+ rightNodeAgg : Array [Array [Array [ Double ] ]],
887912 nodeImpurity : Double ): Array [Array [InformationGainStats ]] = {
888913 val gains = Array .ofDim[InformationGainStats ](numFeatures, numBins - 1 )
889914
0 commit comments