Skip to content

Commit 1892a2c

Browse files
committed
tests and use multiclass binaggregate length when atleast one categorical feature is present
1 parent f5f6b83 commit 1892a2c

File tree

4 files changed

+139
-120
lines changed

4 files changed

+139
-120
lines changed

docs/mllib-decision-tree.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,16 @@ bins if the condition is not satisfied.
7676

7777
**Categorical features**
7878

79-
For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for
80-
binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the
79+
For `$M$` categorical features, one could come up with `$2^(M-1)-1$` split candidates. For
80+
binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
8181
categorical feature values by the proportion of labels falling in one of the two classes (see
8282
Section 9.2.4 in
8383
[Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
8484
details). For example, for a binary classification problem with one categorical feature with three
8585
categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
8686
features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B
8787
and A , B \| C where \| denotes the split.
88-
88+
<!-- -->
8989
### Stopping rule
9090

9191
The recursive tree construction is stopped at a node when one of the two conditions is met:

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 118 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7878
// Max memory usage for aggregates
7979
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
8080
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
81-
val numElementsPerNode =
82-
strategy.algo match {
83-
case Classification => 2 * numBins * numFeatures
84-
case Regression => 3 * numBins * numFeatures
85-
}
81+
val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins,
82+
strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
83+
strategy.algo)
8684

8785
logDebug("numElementsPerNode = " + numElementsPerNode)
8886
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
@@ -144,7 +142,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
144142
new DecisionTreeModel(topNode, strategy.algo)
145143
}
146144

147-
// TODO: Unit test this
148145
/**
149146
* Extract the decision tree node information for the given tree level and node index
150147
*/
@@ -162,7 +159,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
162159
nodes(nodeIndex) = node
163160
}
164161

165-
// TODO: Unit test this
166162
/**
167163
* Extract the decision tree node information for the children of the node
168164
*/
@@ -290,12 +286,12 @@ object DecisionTree extends Serializable with Logging {
290286
* @return a DecisionTreeModel that can be used for prediction
291287
*/
292288
def train(
293-
input: RDD[LabeledPoint],
294-
algo: Algo,
295-
impurity: Impurity,
296-
maxDepth: Int,
297-
numClassesForClassification: Int,
298-
labelWeights: Map[Int,Int]): DecisionTreeModel = {
289+
input: RDD[LabeledPoint],
290+
algo: Algo,
291+
impurity: Impurity,
292+
maxDepth: Int,
293+
numClassesForClassification: Int,
294+
labelWeights: Map[Int,Int]): DecisionTreeModel = {
299295
val strategy
300296
= new Strategy(algo, impurity, maxDepth, numClassesForClassification,
301297
labelWeights = labelWeights)
@@ -462,7 +458,9 @@ object DecisionTree extends Serializable with Logging {
462458
logDebug("labelWeights = " + labelWeights)
463459
val isMulticlassClassification = strategy.isMulticlassClassification
464460
logDebug("isMulticlassClassification = " + isMulticlassClassification)
465-
461+
val isMulticlassClassificationWithCategoricalFeatures
462+
= strategy.isMulticlassWithCategoricalFeatures
463+
logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassClassificationWithCategoricalFeatures)
466464

467465
// shift when more than one group is used at deep tree level
468466
val groupShift = numNodes * groupIndex
@@ -518,9 +516,7 @@ object DecisionTree extends Serializable with Logging {
518516
/**
519517
* Find bin for one feature.
520518
*/
521-
def findBin(
522-
featureIndex: Int,
523-
labeledPoint: WeightedLabeledPoint,
519+
def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint,
524520
isFeatureContinuous: Boolean): Int = {
525521
val binForFeatures = bins(featureIndex)
526522
val feature = labeledPoint.features(featureIndex)
@@ -636,9 +632,48 @@ object DecisionTree extends Serializable with Logging {
636632
}
637633

638634
// Find feature bins for all nodes at a level.
639-
val binMappedRDD = input.map(x => findBinsForLevel(x))
635+
val binMappedRDD = input.map(x => findBinsForLevel(x))
636+
637+
def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int,
638+
label: Double, featureIndex: Int) = {
639+
640+
// Find the bin index for this feature.
641+
val arrShift = 1 + numFeatures * nodeIndex
642+
val arrIndex = arrShift + featureIndex
643+
// Update the left or right count for one bin.
644+
val aggShift = numClasses * numBins * numFeatures * nodeIndex
645+
val aggIndex = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
646+
val labelInt = label.toInt
647+
agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + labelWeights.getOrElse(labelInt, 1)
648+
}
640649

641-
/**
650+
def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
651+
label: Double, agg: Array[Double], rightChildShift: Int) = {
652+
// Find the bin index for this feature.
653+
val arrShift = 1 + numFeatures * nodeIndex
654+
val arrIndex = arrShift + featureIndex
655+
// Update the left or right count for one bin.
656+
val aggShift = numClasses * numBins * numFeatures * nodeIndex
657+
val aggIndex
658+
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
659+
// Find all matching bins and increment their values
660+
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
661+
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
662+
var binIndex = 0
663+
while (binIndex < numCategoricalBins) {
664+
val labelInt = label.toInt
665+
if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) {
666+
agg(aggIndex + binIndex)
667+
= agg(aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1)
668+
} else {
669+
agg(rightChildShift + aggIndex + binIndex)
670+
= agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1)
671+
}
672+
binIndex += 1
673+
}
674+
}
675+
676+
/**
642677
* Performs a sequential aggregation over a partition for classification. For l nodes,
643678
* k features, either the left count or the right count of one of the p bins is
644679
* incremented based upon whether the feature is classified as 0 or 1.
@@ -649,7 +684,7 @@ object DecisionTree extends Serializable with Logging {
649684
* @return Array[Double] storing aggregate calculation of size
650685
* 2 * numSplits * numFeatures * numNodes for classification
651686
*/
652-
def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
687+
def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
653688
// Iterate over all nodes.
654689
var nodeIndex = 0
655690
while (nodeIndex < numNodes) {
@@ -662,93 +697,51 @@ object DecisionTree extends Serializable with Logging {
662697
// Iterate over all features.
663698
var featureIndex = 0
664699
while (featureIndex < numFeatures) {
665-
// Find the bin index for this feature.
666-
val arrShift = 1 + numFeatures * nodeIndex
667-
val arrIndex = arrShift + featureIndex
668-
// Update the left or right count for one bin.
669-
val aggShift = 2 * numBins * numFeatures * nodeIndex
670-
val aggIndex
671-
= aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
672-
label.toInt match {
673-
case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1)
674-
}
700+
updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
675701
featureIndex += 1
676702
}
677703
}
678704
nodeIndex += 1
679705
}
680706
}
681707

682-
/**
683-
* Performs a sequential aggregation over a partition for classification. For l nodes,
684-
* k features, either the left count or the right count of one of the p bins is
685-
* incremented based upon whether the feature is classified as 0 or 1.
686-
*
687-
* @param agg Array[Double] storing aggregate calculation of size
688-
* numClasses * numSplits * numFeatures*numNodes for classification
689-
* @param arr Array[Double] of size 1 + (numFeatures * numNodes)
690-
* @return Array[Double] storing aggregate calculation of size
691-
* 2 * numClasses * numSplits * numFeatures * numNodes for classification
692-
*/
693-
def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
694-
// Iterate over all nodes.
695-
var nodeIndex = 0
696-
while (nodeIndex < numNodes) {
697-
// Check whether the instance was valid for this nodeIndex.
698-
val validSignalIndex = 1 + numFeatures * nodeIndex
699-
val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
700-
if (isSampleValidForNode) {
701-
val rightChildShift = numClasses * numBins * numFeatures * numNodes
702-
// actual class label
703-
val label = arr(0)
704-
// Iterate over all features.
705-
var featureIndex = 0
706-
while (featureIndex < numFeatures) {
707-
val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
708-
if (isContinuousFeature) {
709-
// Find the bin index for this feature.
710-
val arrShift = 1 + numFeatures * nodeIndex
711-
val arrIndex = arrShift + featureIndex
712-
// Update the left or right count for one bin.
713-
val aggShift = numClasses * numBins * numFeatures * nodeIndex
714-
val aggIndex
715-
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
716-
label.toInt match {
717-
case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1)
718-
}
719-
} else {
720-
// Find the bin index for this feature.
721-
val arrShift = 1 + numFeatures * nodeIndex
722-
val arrIndex = arrShift + featureIndex
723-
// Update the left or right count for one bin.
724-
val aggShift = numClasses * numBins * numFeatures * nodeIndex
725-
val aggIndex
726-
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
727-
label.toInt match {
728-
case n: Int =>
729-
// Find all matching bins and increment their values
730-
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
731-
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
732-
var binIndex = 0
733-
while (binIndex < numCategoricalBins) {
734-
if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)) {
735-
agg(aggIndex + binIndex)
736-
= agg(aggIndex + binIndex) + labelWeights.getOrElse(n, 1)
737-
} else {
738-
agg(rightChildShift + aggIndex + binIndex)
739-
= agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(n, 1)
740-
741-
}
742-
binIndex += 1
743-
}
744-
}
745-
}
746-
featureIndex += 1
747-
}
708+
/**
709+
* Performs a sequential aggregation over a partition for classification. For l nodes,
710+
* k features, either the left count or the right count of one of the p bins is
711+
* incremented based upon whether the feature is classified as 0 or 1.
712+
*
713+
* @param agg Array[Double] storing aggregate calculation of size
714+
* numClasses * numSplits * numFeatures*numNodes for classification
715+
* @param arr Array[Double] of size 1 + (numFeatures * numNodes)
716+
* @return Array[Double] storing aggregate calculation of size
717+
* 2 * numClasses * numSplits * numFeatures * numNodes for classification
718+
*/
719+
def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
720+
// Iterate over all nodes.
721+
var nodeIndex = 0
722+
while (nodeIndex < numNodes) {
723+
// Check whether the instance was valid for this nodeIndex.
724+
val validSignalIndex = 1 + numFeatures * nodeIndex
725+
val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
726+
if (isSampleValidForNode) {
727+
val rightChildShift = numClasses * numBins * numFeatures * numNodes
728+
// actual class label
729+
val label = arr(0)
730+
// Iterate over all features.
731+
var featureIndex = 0
732+
while (featureIndex < numFeatures) {
733+
val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
734+
if (isContinuousFeature) {
735+
updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
736+
} else {
737+
updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
738+
}
739+
featureIndex += 1
748740
}
749-
nodeIndex += 1
750741
}
742+
nodeIndex += 1
751743
}
744+
}
752745

753746
/**
754747
* Performs a sequential aggregation over a partition for regression. For l nodes, k features,
@@ -760,7 +753,7 @@ object DecisionTree extends Serializable with Logging {
760753
* @return Array[Double] storing aggregate calculation of size
761754
* 3 * numSplits * numFeatures * numNodes for regression
762755
*/
763-
def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) {
756+
def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
764757
// Iterate over all nodes.
765758
var nodeIndex = 0
766759
while (nodeIndex < numNodes) {
@@ -795,7 +788,7 @@ object DecisionTree extends Serializable with Logging {
795788
def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
796789
strategy.algo match {
797790
case Classification =>
798-
if(isMulticlassClassification) {
791+
if(isMulticlassClassificationWithCategoricalFeatures) {
799792
multiClassificationBinSeqOp(arr, agg)
800793
} else {
801794
binaryClassificationBinSeqOp(arr, agg)
@@ -806,15 +799,8 @@ object DecisionTree extends Serializable with Logging {
806799
}
807800

808801
// Calculate bin aggregate length for classification or regression.
809-
val binAggregateLength = strategy.algo match {
810-
case Classification =>
811-
if (isMulticlassClassification){
812-
2 * numClasses * numBins * numFeatures * numNodes
813-
} else {
814-
2 * numBins * numFeatures * numNodes
815-
}
816-
case Regression => 3 * numBins * numFeatures * numNodes
817-
}
802+
val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
803+
isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
818804
logDebug("binAggregateLength = " + binAggregateLength)
819805

820806
/**
@@ -1024,7 +1010,7 @@ object DecisionTree extends Serializable with Logging {
10241010
}
10251011
}
10261012

1027-
def findAggregateForCategoricalFeatureClassification(
1013+
def findAggForUnorderedFeatureClassification(
10281014
leftNodeAgg: Array[Array[Array[Double]]],
10291015
rightNodeAgg: Array[Array[Array[Double]]],
10301016
featureIndex: Int) {
@@ -1101,12 +1087,12 @@ object DecisionTree extends Serializable with Logging {
11011087
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
11021088
var featureIndex = 0
11031089
while (featureIndex < numFeatures) {
1104-
if (isMulticlassClassification){
1090+
if (isMulticlassClassificationWithCategoricalFeatures){
11051091
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
11061092
if (isFeatureContinuous) {
11071093
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
11081094
} else {
1109-
findAggregateForCategoricalFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
1095+
findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
11101096
}
11111097
} else {
11121098
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
@@ -1214,7 +1200,7 @@ object DecisionTree extends Serializable with Logging {
12141200
def getBinDataForNode(node: Int): Array[Double] = {
12151201
strategy.algo match {
12161202
case Classification =>
1217-
if (isMulticlassClassification) {
1203+
if (isMulticlassClassificationWithCategoricalFeatures) {
12181204
val shift = numClasses * node * numBins * numFeatures
12191205
val rightChildShift = numClasses * numBins * numFeatures * numNodes
12201206
val binsForNode = {
@@ -1251,10 +1237,22 @@ object DecisionTree extends Serializable with Logging {
12511237
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
12521238
node += 1
12531239
}
1254-
12551240
bestSplits
12561241
}
12571242

1243+
private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int,
1244+
isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = {
1245+
algo match {
1246+
case Classification =>
1247+
if (isMulticlassClassificationWithCategoricalFeatures) {
1248+
2 * numClasses * numBins * numFeatures
1249+
} else {
1250+
numClasses * numBins * numFeatures
1251+
}
1252+
case Regression => 3 * numBins * numFeatures
1253+
}
1254+
}
1255+
12581256
/**
12591257
* Returns split and bins for decision tree calculation.
12601258
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
@@ -1288,9 +1286,12 @@ object DecisionTree extends Serializable with Logging {
12881286
*/
12891287
if (strategy.categoricalFeaturesInfo.size > 0) {
12901288
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
1291-
require(numBins > maxCategoriesForFeatures)
1289+
require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
1290+
"in categorical features")
12921291
if (isMulticlassClassification) {
1293-
require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1)
1292+
require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1,
1293+
"numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" +
1294+
" with categorical variables")
12941295
}
12951296
}
12961297

@@ -1331,7 +1332,8 @@ object DecisionTree extends Serializable with Logging {
13311332
} else { // Categorical feature
13321333
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
13331334

1334-
// Use different bin/split calculation strategy for multiclass classification
1335+
// Use different bin/split calculation strategy for categorical features in multiclass
1336+
// classification
13351337
if (isMulticlassClassification) {
13361338
// 2^(maxFeatureValue- 1) - 1 combinations
13371339
var index = 0

0 commit comments

Comments
 (0)