Skip to content

Commit 6c7af22

Browse files
committed
prepared for multiclass without breaking binary classification
1 parent 46e06ee commit 6c7af22

File tree

1 file changed

+107
-82
lines changed

1 file changed

+107
-82
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 107 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,8 @@ object DecisionTree extends Serializable with Logging {
385385
logDebug("numFeatures = " + numFeatures)
386386
val numBins = bins(0).length
387387
logDebug("numBins = " + numBins)
388+
val numClasses = strategy.numClassesForClassification
389+
logDebug("numClasses = " + numClasses)
388390

389391
// shift when more than one group is used at deep tree level
390392
val groupShift = numNodes * groupIndex
@@ -545,10 +547,10 @@ object DecisionTree extends Serializable with Logging {
545547
* incremented based upon whether the feature is classified as 0 or 1.
546548
*
547549
* @param agg Array[Double] storing aggregate calculation of size
548-
* 2 * numSplits * numFeatures*numNodes for classification
550+
* numClasses * numSplits * numFeatures*numNodes for classification
549551
* @param arr Array[Double] of size 1 + (numFeatures * numNodes)
550552
* @return Array[Double] storing aggregate calculation of size
551-
* 2 * numSplits * numFeatures * numNodes for classification
553+
* numClasses * numSplits * numFeatures * numNodes for classification
552554
*/
553555
def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
554556
// Iterate over all nodes.
@@ -562,16 +564,16 @@ object DecisionTree extends Serializable with Logging {
562564
val label = arr(0)
563565
// Iterate over all features.
564566
var featureIndex = 0
565-
// TODO: Multiclass modification here
566567
while (featureIndex < numFeatures) {
567568
// Find the bin index for this feature.
568569
val arrShift = 1 + numFeatures * nodeIndex
569570
val arrIndex = arrShift + featureIndex
570571
// Update the left or right count for one bin.
571-
val aggShift = 2 * numBins * numFeatures * nodeIndex
572-
val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
573-
label match {
574-
case n: Double => agg(aggIndex) = agg(aggIndex + n.toInt) + 1
572+
val aggShift = numClasses * numBins * numFeatures * nodeIndex
573+
val aggIndex
574+
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
575+
label.toInt match {
576+
case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + 1
575577
}
576578
featureIndex += 1
577579
}
@@ -632,7 +634,7 @@ object DecisionTree extends Serializable with Logging {
632634

633635
// Calculate bin aggregate length for classification or regression.
634636
val binAggregateLength = strategy.algo match {
635-
case Classification => 2 * numBins * numFeatures * numNodes
637+
case Classification => numClasses * numBins * numFeatures * numNodes
636638
case Regression => 3 * numBins * numFeatures * numNodes
637639
}
638640
logDebug("binAggregateLength = " + binAggregateLength)
@@ -672,20 +674,20 @@ object DecisionTree extends Serializable with Logging {
672674
* @return information gain and statistics for all splits
673675
*/
674676
def calculateGainForSplit(
675-
leftNodeAgg: Array[Array[Double]],
677+
leftNodeAgg: Array[Array[Array[Double]]],
676678
featureIndex: Int,
677679
splitIndex: Int,
678-
rightNodeAgg: Array[Array[Double]],
680+
rightNodeAgg: Array[Array[Array[Double]]],
679681
topImpurity: Double): InformationGainStats = {
680682
strategy.algo match {
681683
case Classification =>
682684
// TODO: Modify here
683-
val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
684-
val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
685+
val left0Count = leftNodeAgg(featureIndex)(splitIndex)(0)
686+
val left1Count = leftNodeAgg(featureIndex)(splitIndex)(1)
685687
val leftCount = left0Count + left1Count
686688

687-
val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
688-
val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
689+
val right0Count = rightNodeAgg(featureIndex)(splitIndex)(0)
690+
val right1Count = rightNodeAgg(featureIndex)(splitIndex)(1)
689691
val rightCount = right0Count + right1Count
690692

691693
val impurity = {
@@ -722,13 +724,13 @@ object DecisionTree extends Serializable with Logging {
722724

723725
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
724726
case Regression =>
725-
val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
726-
val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
727-
val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
727+
val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
728+
val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
729+
val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2)
728730

729-
val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
730-
val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
731-
val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
731+
val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0)
732+
val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1)
733+
val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2)
732734

733735
val impurity = {
734736
if (level > 0) {
@@ -777,98 +779,121 @@ object DecisionTree extends Serializable with Logging {
777779
* Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
778780
*/
779781
def extractLeftRightNodeAggregates(
780-
binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
782+
binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
781783
strategy.algo match {
782784
case Classification =>
783785
// TODO: Multiclass modification here
784-
// Initialize left and right split aggregates.
785-
val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
786-
val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
787-
// Iterate over all features.
788-
var featureIndex = 0
789-
while (featureIndex < numFeatures) {
790-
// shift for this featureIndex
791-
val shift = 2 * featureIndex * numBins
792-
793-
// left node aggregate for the lowest split
794-
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
795-
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
796-
797-
// right node aggregate for the highest split
798-
rightNodeAgg(featureIndex)(2 * (numBins - 2))
799-
= binData(shift + (2 * (numBins - 1)))
800-
rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1)
801-
= binData(shift + (2 * (numBins - 1)) + 1)
802-
803-
// Iterate over all splits.
804-
var splitIndex = 1
805-
while (splitIndex < numBins - 1) {
806-
// calculating left node aggregate for a split as a sum of left node aggregate of a
807-
// lower split and the left bin aggregate of a bin where the split is a high split
808-
leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) +
809-
leftNodeAgg(featureIndex)(2 * splitIndex - 2)
810-
leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) +
811-
leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
812786

813-
// calculating right node aggregate for a split as a sum of right node aggregate of a
814-
// higher split and the right bin aggregate of a bin where the split is a low split
815-
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
816-
binData(shift + (2 *(numBins - 2 - splitIndex))) +
817-
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
818-
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
819-
binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
820-
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
821-
822-
splitIndex += 1
787+
// Initialize left and right split aggregates.
788+
val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
789+
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
790+
791+
if (strategy.isMultiClassification) {
792+
var featureIndex = 0
793+
while (featureIndex < numFeatures){
794+
val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
795+
val maxSplits = math.pow(2, numCategories) - 1
796+
var i = 0
797+
// TODO: Add multiclass case here
798+
while (i < maxSplits) {
799+
var classIndex = 0
800+
while (classIndex < numClasses) {
801+
// shift for this featureIndex
802+
val shift = numClasses * featureIndex * numBins
803+
804+
classIndex += 1
805+
}
806+
i += 1
807+
}
808+
featureIndex += 1
809+
}
810+
} else {
811+
// Iterate over all features.
812+
var featureIndex = 0
813+
while (featureIndex < numFeatures) {
814+
// shift for this featureIndex
815+
val shift = 2 * featureIndex * numBins
816+
817+
// left node aggregate for the lowest split
818+
leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
819+
leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
820+
821+
// right node aggregate for the highest split
822+
rightNodeAgg(featureIndex)(numBins - 2)(0)
823+
= binData(shift + (2 * (numBins - 1)))
824+
rightNodeAgg(featureIndex)(numBins - 2)(1)
825+
= binData(shift + (2 * (numBins - 1)) + 1)
826+
827+
// Iterate over all splits.
828+
var splitIndex = 1
829+
while (splitIndex < numBins - 1) {
830+
// calculating left node aggregate for a split as a sum of left node aggregate of a
831+
// lower split and the left bin aggregate of a bin where the split is a high split
832+
leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) +
833+
leftNodeAgg(featureIndex)(splitIndex - 1)(0)
834+
leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex +
835+
1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1)
836+
837+
// calculating right node aggregate for a split as a sum of right node aggregate of a
838+
// higher split and the right bin aggregate of a bin where the split is a low split
839+
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) =
840+
binData(shift + (2 *(numBins - 2 - splitIndex))) +
841+
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0)
842+
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) =
843+
binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
844+
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1)
845+
846+
splitIndex += 1
847+
}
848+
featureIndex += 1
823849
}
824-
featureIndex += 1
825850
}
826851
(leftNodeAgg, rightNodeAgg)
827852
case Regression =>
828853
// Initialize left and right split aggregates.
829-
val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
830-
val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
854+
val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
855+
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
831856
// Iterate over all features.
832857
var featureIndex = 0
833858
while (featureIndex < numFeatures) {
834859
// shift for this featureIndex
835860
val shift = 3 * featureIndex * numBins
836861
// left node aggregate for the lowest split
837-
leftNodeAgg(featureIndex)(0) = binData(shift + 0)
838-
leftNodeAgg(featureIndex)(1) = binData(shift + 1)
839-
leftNodeAgg(featureIndex)(2) = binData(shift + 2)
862+
leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
863+
leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
864+
leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2)
840865

841866
// right node aggregate for the highest split
842-
rightNodeAgg(featureIndex)(3 * (numBins - 2)) =
867+
rightNodeAgg(featureIndex)(numBins - 2)(0) =
843868
binData(shift + (3 * (numBins - 1)))
844-
rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) =
869+
rightNodeAgg(featureIndex)(numBins - 2)(1) =
845870
binData(shift + (3 * (numBins - 1)) + 1)
846-
rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) =
871+
rightNodeAgg(featureIndex)(numBins - 2)(2) =
847872
binData(shift + (3 * (numBins - 1)) + 2)
848873

849874
// Iterate over all splits.
850875
var splitIndex = 1
851876
while (splitIndex < numBins - 1) {
852877
// calculating left node aggregate for a split as a sum of left node aggregate of a
853878
// lower split and the left bin aggregate of a bin where the split is a high split
854-
leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) +
855-
leftNodeAgg(featureIndex)(3 * splitIndex - 3)
856-
leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) +
857-
leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
858-
leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) +
859-
leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
879+
leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * splitIndex) +
880+
leftNodeAgg(featureIndex)(splitIndex - 1)(0)
881+
leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 3 * splitIndex + 1) +
882+
leftNodeAgg(featureIndex)(splitIndex - 1)(1)
883+
leftNodeAgg(featureIndex)(splitIndex)(2) = binData(shift + 3 * splitIndex + 2) +
884+
leftNodeAgg(featureIndex)(splitIndex - 1)(2)
860885

861886
// calculating right node aggregate for a split as a sum of right node aggregate of a
862887
// higher split and the right bin aggregate of a bin where the split is a low split
863-
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
888+
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) =
864889
binData(shift + (3 * (numBins - 2 - splitIndex))) +
865-
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
866-
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
890+
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0)
891+
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) =
867892
binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
868-
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
869-
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
893+
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1)
894+
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) =
870895
binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
871-
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
896+
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2)
872897

873898
splitIndex += 1
874899
}
@@ -882,8 +907,8 @@ object DecisionTree extends Serializable with Logging {
882907
* Calculates information gain for all nodes splits.
883908
*/
884909
def calculateGainsForAllNodeSplits(
885-
leftNodeAgg: Array[Array[Double]],
886-
rightNodeAgg: Array[Array[Double]],
910+
leftNodeAgg: Array[Array[Array[Double]]],
911+
rightNodeAgg: Array[Array[Array[Double]]],
887912
nodeImpurity: Double): Array[Array[InformationGainStats]] = {
888913
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
889914

0 commit comments

Comments
 (0)