@@ -144,6 +144,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
144144 new DecisionTreeModel (topNode, strategy.algo)
145145 }
146146
147+ // TODO: Unit test this
147148 /**
148149 * Extract the decision tree node information for the given tree level and node index
149150 */
@@ -161,6 +162,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
161162 nodes(nodeIndex) = node
162163 }
163164
165+ // TODO: Unit test this
164166 /**
165167 * Extract the decision tree node information for the children of the node
166168 */
@@ -458,6 +460,8 @@ object DecisionTree extends Serializable with Logging {
458460 logDebug(" numClasses = " + numClasses)
459461 val labelWeights = strategy.labelWeights
460462 logDebug(" labelWeights = " + labelWeights)
463+ val isMulticlassClassification = strategy.isMulticlassClassification
464+ logDebug(" isMulticlassClassification = " + isMulticlassClassification)
461465
462466
463467 // shift when more than one group is used at deep tree level
@@ -582,7 +586,7 @@ object DecisionTree extends Serializable with Logging {
582586 } else {
583587 // Perform sequential search to find bin for categorical features.
584588 val binIndex = {
585- if (strategy.isMultiClassification ) {
589+ if (isMulticlassClassification ) {
586590 sequentialBinSearchForCategoricalFeatureInBinaryClassification()
587591 } else {
588592 sequentialBinSearchForCategoricalFeatureInMultiClassClassification()
@@ -606,7 +610,9 @@ object DecisionTree extends Serializable with Logging {
606610 def findBinsForLevel (labeledPoint : WeightedLabeledPoint ): Array [Double ] = {
607611 // Calculate bin index and label per feature per node.
608612 val arr = new Array [Double ](1 + (numFeatures * numNodes))
613+ // First element of the array is the label of the instance.
609614 arr(0 ) = labeledPoint.label
615+ // Iterate over nodes.
610616 var nodeIndex = 0
611617 while (nodeIndex < numNodes) {
612618 val parentFilters = findParentFilters(nodeIndex)
@@ -629,7 +635,10 @@ object DecisionTree extends Serializable with Logging {
629635 arr
630636 }
631637
632- /**
638+ // Find feature bins for all nodes at a level.
639+ val binMappedRDD = input.map(x => findBinsForLevel(x))
640+
641+ /**
633642 * Performs a sequential aggregation over a partition for classification. For l nodes,
634643 * k features, either the left count or the right count of one of the p bins is
635644 * incremented based upon whether the feature is classified as 0 or 1.
@@ -663,7 +672,7 @@ object DecisionTree extends Serializable with Logging {
663672 label.toInt match {
664673 case n : Int =>
665674 val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
666- if (! isFeatureContinuous && strategy.isMultiClassification ) {
675+ if (! isFeatureContinuous && isMulticlassClassification ) {
667676 // Find all matching bins and increment their values
668677 val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
669678 val numCategoricalBins = math.pow(2.0 , featureCategories - 1 ).toInt - 1
@@ -736,7 +745,6 @@ object DecisionTree extends Serializable with Logging {
736745 agg
737746 }
738747
739- // TODO: Double-check this
740748 // Calculate bin aggregate length for classification or regression.
741749 val binAggregateLength = strategy.algo match {
742750 case Classification => numClasses * numBins * numFeatures * numNodes
@@ -760,9 +768,6 @@ object DecisionTree extends Serializable with Logging {
760768 combinedAggregate
761769 }
762770
763- // Find feature bins for all nodes at a level.
764- val binMappedRDD = input.map(x => findBinsForLevel(x))
765-
766771 // Calculate bin aggregates.
767772 val binAggregates = {
768773 binMappedRDD.aggregate(Array .fill[Double ](binAggregateLength)(0 ))(binSeqOp,binCombOp)
@@ -922,7 +927,7 @@ object DecisionTree extends Serializable with Logging {
922927 val leftNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , numClasses)
923928 val rightNodeAgg = Array .ofDim[Double ](numFeatures, numBins - 1 , numClasses)
924929
925- if (strategy.isMultiClassification ) {
930+ if (isMulticlassClassification ) {
926931 var featureIndex = 0
927932 while (featureIndex < numFeatures){
928933 var splitIndex = 0
@@ -1096,7 +1101,7 @@ object DecisionTree extends Serializable with Logging {
10961101 numBins - 1
10971102 } else { // Categorical feature
10981103 val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1099- if (strategy.isMultiClassification ) {
1104+ if (isMulticlassClassification ) {
11001105 math.pow(2.0 , featureCategories - 1 ).toInt - 1
11011106 } else { // Binary classification
11021107 featureCategories
@@ -1177,6 +1182,9 @@ object DecisionTree extends Serializable with Logging {
11771182 val maxBins = strategy.maxBins
11781183 val numBins = if (maxBins <= count) maxBins else count.toInt
11791184 logDebug(" numBins = " + numBins)
1185+ val isMulticlassClassification = strategy.isMulticlassClassification
1186+ logDebug(" isMulticlassClassification = " + isMulticlassClassification)
1187+
11801188
11811189 /*
11821190 * Ensure #bins is always greater than the categories. For multiclass classification,
@@ -1187,7 +1195,7 @@ object DecisionTree extends Serializable with Logging {
11871195 if (strategy.categoricalFeaturesInfo.size > 0 ) {
11881196 val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
11891197 require(numBins > maxCategoriesForFeatures)
1190- if (strategy.isMultiClassification ) {
1198+ if (isMulticlassClassification ) {
11911199 require(numBins > math.pow(2 , maxCategoriesForFeatures.toInt - 1 ) - 1 )
11921200 }
11931201 }
@@ -1230,7 +1238,7 @@ object DecisionTree extends Serializable with Logging {
12301238 val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
12311239
12321240 // Use different bin/split calculation strategy for multiclass classification
1233- if (strategy.isMultiClassification ) {
1241+ if (isMulticlassClassification ) {
12341242 // 2^(maxFeatureValue - 1) - 1 combinations
12351243 var index = 0
12361244 while (index < math.pow(2.0 , featureCategories - 1 ).toInt - 1 ) {
0 commit comments