@@ -78,11 +78,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7878 // Max memory usage for aggregates
7979 val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
8080 logDebug(" max memory usage for aggregates = " + maxMemoryUsage + " bytes." )
81- val numElementsPerNode =
82- strategy.algo match {
83- case Classification => 2 * numBins * numFeatures
84- case Regression => 3 * numBins * numFeatures
85- }
81+ val numElementsPerNode = DecisionTree .getElementsPerNode(numFeatures, numBins,
82+ strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
83+ strategy.algo)
8684
8785 logDebug(" numElementsPerNode = " + numElementsPerNode)
8886 val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
@@ -144,7 +142,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
144142 new DecisionTreeModel (topNode, strategy.algo)
145143 }
146144
147- // TODO: Unit test this
148145 /**
149146 * Extract the decision tree node information for the given tree level and node index
150147 */
@@ -162,7 +159,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
162159 nodes(nodeIndex) = node
163160 }
164161
165- // TODO: Unit test this
166162 /**
167163 * Extract the decision tree node information for the children of the node
168164 */
@@ -290,12 +286,12 @@ object DecisionTree extends Serializable with Logging {
290286 * @return a DecisionTreeModel that can be used for prediction
291287 */
292288 def train (
293- input : RDD [LabeledPoint ],
294- algo : Algo ,
295- impurity : Impurity ,
296- maxDepth : Int ,
297- numClassesForClassification : Int ,
298- labelWeights : Map [Int ,Int ]): DecisionTreeModel = {
289+ input : RDD [LabeledPoint ],
290+ algo : Algo ,
291+ impurity : Impurity ,
292+ maxDepth : Int ,
293+ numClassesForClassification : Int ,
294+ labelWeights : Map [Int ,Int ]): DecisionTreeModel = {
299295 val strategy
300296 = new Strategy (algo, impurity, maxDepth, numClassesForClassification,
301297 labelWeights = labelWeights)
@@ -462,7 +458,9 @@ object DecisionTree extends Serializable with Logging {
462458 logDebug(" labelWeights = " + labelWeights)
463459 val isMulticlassClassification = strategy.isMulticlassClassification
464460 logDebug(" isMulticlassClassification = " + isMulticlassClassification)
465-
461+ val isMulticlassClassificationWithCategoricalFeatures
462+ = strategy.isMulticlassWithCategoricalFeatures
463+ logDebug(" isMultiClassWithCategoricalFeatures = " + isMulticlassClassificationWithCategoricalFeatures)
466464
467465 // shift when more than one group is used at deep tree level
468466 val groupShift = numNodes * groupIndex
@@ -518,9 +516,7 @@ object DecisionTree extends Serializable with Logging {
518516 /**
519517 * Find bin for one feature.
520518 */
521- def findBin (
522- featureIndex : Int ,
523- labeledPoint : WeightedLabeledPoint ,
519+ def findBin (featureIndex : Int , labeledPoint : WeightedLabeledPoint ,
524520 isFeatureContinuous : Boolean ): Int = {
525521 val binForFeatures = bins(featureIndex)
526522 val feature = labeledPoint.features(featureIndex)
@@ -636,9 +632,48 @@ object DecisionTree extends Serializable with Logging {
636632 }
637633
638634 // Find feature bins for all nodes at a level.
639- val binMappedRDD = input.map(x => findBinsForLevel(x))
635+ val binMappedRDD = input.map(x => findBinsForLevel(x))
636+
/**
 * Adds the weight of one sample's label to the aggregate slot of the bin that the sample
 * was assigned to, for a single (node, feature) pair.
 *
 * @param arr per-sample array; the entry at `1 + numFeatures * nodeIndex + featureIndex`
 *            is read as the bin index for this node and feature
 * @param agg aggregate array, updated in place
 * @param nodeIndex index of the node being aggregated
 * @param label label of the sample; truncated to an Int to pick the class slot
 * @param featureIndex index of the feature being aggregated
 */
def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int,
    label: Double, featureIndex: Int): Unit = {
  // Bin assigned to this sample for (nodeIndex, featureIndex).
  val assignedBin = arr(1 + numFeatures * nodeIndex + featureIndex).toInt
  // agg is laid out by node, then feature, then bin, with numClasses slots per bin.
  val nodeOffset = numClasses * numBins * numFeatures * nodeIndex
  val featureOffset = numClasses * featureIndex * numBins
  val classIndex = label.toInt
  val slot = nodeOffset + featureOffset + assignedBin * numClasses + classIndex
  // Labels without an explicit weight count as 1.
  agg(slot) += labelWeights.getOrElse(classIndex, 1)
}
640649
641- /**
/**
 * Updates the aggregate counts of every categorical bin of one unordered feature for a
 * single node. For each bin, the label weight is added either to the slot at the base
 * index or to the same slot shifted by `rightChildShift`.
 *
 * NOTE(review): the membership test below checks the sample's label (not a feature value)
 * against the categories of the bin's high split -- confirm this is the intended operand.
 *
 * @param nodeIndex index of the node being aggregated
 * @param featureIndex index of the categorical feature being aggregated
 * @param arr per-sample array; the entry at `1 + numFeatures * nodeIndex + featureIndex`
 *            is read as the bin index for this node and feature
 * @param label label of the sample; truncated to an Int
 * @param agg aggregate array, updated in place
 * @param rightChildShift offset added to the slot index when the membership test fails
 */
def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
    label: Double, agg: Array[Double], rightChildShift: Int): Unit = {
  // Bin assigned to this sample for (nodeIndex, featureIndex).
  val assignedBin = arr(1 + numFeatures * nodeIndex + featureIndex).toInt
  // Base slot: node offset + feature offset + bin offset.
  val baseIndex = numClasses * numBins * numFeatures * nodeIndex +
    numClasses * featureIndex * numBins + assignedBin * numClasses
  // Number of category subsets considered for this feature: 2^(categories - 1) - 1.
  val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
  val numCategoricalBins = math.pow(2.0, numCategories - 1).toInt - 1
  // Both values are the same for every bin, so compute them once.
  val classIndex = label.toInt
  val weight = labelWeights.getOrElse(classIndex, 1)
  var binIndex = 0
  while (binIndex < numCategoricalBins) {
    val inSplit = bins(featureIndex)(binIndex).highSplit.categories.contains(classIndex)
    val slot =
      if (inSplit) baseIndex + binIndex else rightChildShift + baseIndex + binIndex
    agg(slot) += weight
    binIndex += 1
  }
}
675+
676+ /**
642677 * Performs a sequential aggregation over a partition for classification. For l nodes,
643678 * k features, either the left count or the right count of one of the p bins is
644679 * incremented based upon whether the feature is classified as 0 or 1.
@@ -649,7 +684,7 @@ object DecisionTree extends Serializable with Logging {
649684 * @return Array[Double] storing aggregate calculation of size
650685 * 2 * numSplits * numFeatures * numNodes for classification
651686 */
652- def binaryClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) {
687+ def binaryClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
653688 // Iterate over all nodes.
654689 var nodeIndex = 0
655690 while (nodeIndex < numNodes) {
@@ -662,93 +697,51 @@ object DecisionTree extends Serializable with Logging {
662697 // Iterate over all features.
663698 var featureIndex = 0
664699 while (featureIndex < numFeatures) {
665- // Find the bin index for this feature.
666- val arrShift = 1 + numFeatures * nodeIndex
667- val arrIndex = arrShift + featureIndex
668- // Update the left or right count for one bin.
669- val aggShift = 2 * numBins * numFeatures * nodeIndex
670- val aggIndex
671- = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
672- label.toInt match {
673- case n : Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1 )
674- }
700+ updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
675701 featureIndex += 1
676702 }
677703 }
678704 nodeIndex += 1
679705 }
680706 }
681707
682- /**
683- * Performs a sequential aggregation over a partition for classification. For l nodes,
684- * k features, either the left count or the right count of one of the p bins is
685- * incremented based upon whether the feature is classified as 0 or 1.
686- *
687- * @param agg Array[Double] storing aggregate calculation of size
688- * numClasses * numSplits * numFeatures*numNodes for classification
689- * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
690- * @return Array[Double] storing aggregate calculation of size
691- * 2 * numClasses * numSplits * numFeatures * numNodes for classification
692- */
693- def multiClassificationBinSeqOp (arr : Array [Double ], agg : Array [Double ]) {
694- // Iterate over all nodes.
695- var nodeIndex = 0
696- while (nodeIndex < numNodes) {
697- // Check whether the instance was valid for this nodeIndex.
698- val validSignalIndex = 1 + numFeatures * nodeIndex
699- val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
700- if (isSampleValidForNode) {
701- val rightChildShift = numClasses * numBins * numFeatures * numNodes
702- // actual class label
703- val label = arr(0 )
704- // Iterate over all features.
705- var featureIndex = 0
706- while (featureIndex < numFeatures) {
707- val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
708- if (isContinuousFeature) {
709- // Find the bin index for this feature.
710- val arrShift = 1 + numFeatures * nodeIndex
711- val arrIndex = arrShift + featureIndex
712- // Update the left or right count for one bin.
713- val aggShift = numClasses * numBins * numFeatures * nodeIndex
714- val aggIndex
715- = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
716- label.toInt match {
717- case n : Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1 )
718- }
719- } else {
720- // Find the bin index for this feature.
721- val arrShift = 1 + numFeatures * nodeIndex
722- val arrIndex = arrShift + featureIndex
723- // Update the left or right count for one bin.
724- val aggShift = numClasses * numBins * numFeatures * nodeIndex
725- val aggIndex
726- = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
727- label.toInt match {
728- case n : Int =>
729- // Find all matching bins and increment their values
730- val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
731- val numCategoricalBins = math.pow(2.0 , featureCategories - 1 ).toInt - 1
732- var binIndex = 0
733- while (binIndex < numCategoricalBins) {
734- if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)) {
735- agg(aggIndex + binIndex)
736- = agg(aggIndex + binIndex) + labelWeights.getOrElse(n, 1 )
737- } else {
738- agg(rightChildShift + aggIndex + binIndex)
739- = agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(n, 1 )
740-
741- }
742- binIndex += 1
743- }
744- }
745- }
746- featureIndex += 1
747- }
/**
 * Sequential aggregation step for multiclass classification with categorical features.
 * For every node the sample is valid for, and for every feature, the bin aggregates in
 * `agg` are updated in place: features without an entry in
 * `strategy.categoricalFeaturesInfo` go through `updateBinForOrderedFeature`, all others
 * through `updateBinForUnorderedFeature`.
 *
 * @param arr per-sample array; index 0 holds the label, and the entry at
 *            `1 + numFeatures * nodeIndex` equals InvalidBinIndex when the sample is
 *            not valid for that node
 * @param agg aggregate array, updated in place
 */
def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
  // Offset passed to the unordered-feature update for its shifted (right child) slots.
  val rightChildShift = numClasses * numBins * numFeatures * numNodes
  var nodeIndex = 0
  while (nodeIndex < numNodes) {
    // Skip nodes that this sample does not belong to.
    if (arr(1 + numFeatures * nodeIndex) != InvalidBinIndex) {
      // Actual class label of the sample.
      val label = arr(0)
      var featureIndex = 0
      while (featureIndex < numFeatures) {
        val isCategoricalFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isDefined
        if (isCategoricalFeature) {
          updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
        } else {
          updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
        }
        featureIndex += 1
      }
    }
    nodeIndex += 1
  }
}
752745
753746 /**
754747 * Performs a sequential aggregation over a partition for regression. For l nodes, k features,
@@ -760,7 +753,7 @@ object DecisionTree extends Serializable with Logging {
760753 * @return Array[Double] storing aggregate calculation of size
761754 * 3 * numSplits * numFeatures * numNodes for regression
762755 */
763- def regressionBinSeqOp (arr : Array [Double ], agg : Array [Double ]) {
756+ def regressionBinSeqOp (arr : Array [Double ], agg : Array [Double ]) = {
764757 // Iterate over all nodes.
765758 var nodeIndex = 0
766759 while (nodeIndex < numNodes) {
@@ -795,7 +788,7 @@ object DecisionTree extends Serializable with Logging {
795788 def binSeqOp (agg : Array [Double ], arr : Array [Double ]): Array [Double ] = {
796789 strategy.algo match {
797790 case Classification =>
798- if (isMulticlassClassification ) {
791+ if (isMulticlassClassificationWithCategoricalFeatures ) {
799792 multiClassificationBinSeqOp(arr, agg)
800793 } else {
801794 binaryClassificationBinSeqOp(arr, agg)
@@ -806,15 +799,8 @@ object DecisionTree extends Serializable with Logging {
806799 }
807800
808801 // Calculate bin aggregate length for classification or regression.
809- val binAggregateLength = strategy.algo match {
810- case Classification =>
811- if (isMulticlassClassification){
812- 2 * numClasses * numBins * numFeatures * numNodes
813- } else {
814- 2 * numBins * numFeatures * numNodes
815- }
816- case Regression => 3 * numBins * numFeatures * numNodes
817- }
802+ val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
803+ isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
818804 logDebug(" binAggregateLength = " + binAggregateLength)
819805
820806 /**
@@ -1024,7 +1010,7 @@ object DecisionTree extends Serializable with Logging {
10241010 }
10251011 }
10261012
1027- def findAggregateForCategoricalFeatureClassification (
1013+ def findAggForUnorderedFeatureClassification (
10281014 leftNodeAgg : Array [Array [Array [Double ]]],
10291015 rightNodeAgg : Array [Array [Array [Double ]]],
10301016 featureIndex : Int ) {
@@ -1101,12 +1087,12 @@ object DecisionTree extends Serializable with Logging {
11011087 val rightNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , numClasses)
11021088 var featureIndex = 0
11031089 while (featureIndex < numFeatures) {
1104- if (isMulticlassClassification ){
1090+ if (isMulticlassClassificationWithCategoricalFeatures ){
11051091 val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
11061092 if (isFeatureContinuous) {
11071093 findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
11081094 } else {
1109- findAggregateForCategoricalFeatureClassification (leftNodeAgg, rightNodeAgg, featureIndex)
1095+ findAggForUnorderedFeatureClassification (leftNodeAgg, rightNodeAgg, featureIndex)
11101096 }
11111097 } else {
11121098 findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
@@ -1214,7 +1200,7 @@ object DecisionTree extends Serializable with Logging {
12141200 def getBinDataForNode (node : Int ): Array [Double ] = {
12151201 strategy.algo match {
12161202 case Classification =>
1217- if (isMulticlassClassification ) {
1203+ if (isMulticlassClassificationWithCategoricalFeatures ) {
12181204 val shift = numClasses * node * numBins * numFeatures
12191205 val rightChildShift = numClasses * numBins * numFeatures * numNodes
12201206 val binsForNode = {
@@ -1251,10 +1237,22 @@ object DecisionTree extends Serializable with Logging {
12511237 bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
12521238 node += 1
12531239 }
1254-
12551240 bestSplits
12561241 }
12571242
/**
 * Number of aggregate-array elements required for a single tree node.
 *
 * @param numFeatures number of features
 * @param numBins number of bins per feature
 * @param numClasses number of classes (only used for classification)
 * @param isMulticlassClassificationWithCategoricalFeatures when true, classification uses
 *        twice as many elements per node
 * @param algo Classification or Regression
 * @return `2 * numClasses * numBins * numFeatures` for multiclass classification with
 *         categorical features, `numClasses * numBins * numFeatures` for any other
 *         classification, and `3 * numBins * numFeatures` for regression
 */
private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int,
    isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = {
  // Factor shared by every case.
  val binsAcrossFeatures = numBins * numFeatures
  algo match {
    case Classification if isMulticlassClassificationWithCategoricalFeatures =>
      2 * numClasses * binsAcrossFeatures
    case Classification =>
      numClasses * binsAcrossFeatures
    case Regression =>
      3 * binsAcrossFeatures
  }
}
1255+
12581256 /**
12591257 * Returns split and bins for decision tree calculation.
12601258 * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
@@ -1288,9 +1286,12 @@ object DecisionTree extends Serializable with Logging {
12881286 */
12891287 if (strategy.categoricalFeaturesInfo.size > 0 ) {
12901288 val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
1291- require(numBins > maxCategoriesForFeatures)
1289+ require(numBins > maxCategoriesForFeatures, " numBins should be greater than max categories " +
1290+ " in categorical features" )
12921291 if (isMulticlassClassification) {
1293- require(numBins > math.pow(2 , maxCategoriesForFeatures.toInt - 1 ) - 1 )
1292+ require(numBins > math.pow(2 , maxCategoriesForFeatures.toInt - 1 ) - 1 ,
1293+ " numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" +
1294+ " with categorical variables" )
12941295 }
12951296 }
12961297
@@ -1331,7 +1332,8 @@ object DecisionTree extends Serializable with Logging {
13311332 } else { // Categorical feature
13321333 val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
13331334
1334- // Use different bin/split calculation strategy for multiclass classification
1335+ // Use different bin/split calculation strategy for categorical features in multiclass
1336+ // classification
13351337 if (isMulticlassClassification) {
13361338 // 2^(maxFeatureValue- 1) - 1 combinations
13371339 var index = 0
0 commit comments