Skip to content

Commit d88f6be

Browse files
manishamdemengxr
authored andcommitted
[MLlib] SPARK-1536: multiclass classification support for decision tree
The ability to perform multiclass classification is a big advantage for using decision trees and was a highly requested feature for mllib. This pull request adds multiclass classification support to the MLlib decision tree. It also adds sample weights support using WeightedLabeledPoint class for handling unbalanced datasets during classification. It will also support algorithms such as AdaBoost which requires instances to be weighted. It handles the special case where the categorical variables cannot be ordered for multiclass classification and thus the optimizations used for speeding up binary classification cannot be directly used for multiclass classification with categorical variables. More specifically, for m categories in a categorical feature, it analyses all the ```2^(m-1) - 1``` categorical splits provided that #splits are less than the maxBins provided in the input. This condition will not be met for features with large number of categories -- using decision trees is not recommended for such datasets in general since the categorical features are favored over continuous features. Moreover, the user can use a combination of tricks (increasing bin size of the tree algorithms, use binary encoding for categorical features or use one-vs-all classification strategy) to avoid these constraints. The new code is accompanied by unit tests and has also been tested on the iris and covtype datasets. cc: mengxr, etrain, hirakendu, atalwalkar, srowen Author: Manish Amde <[email protected]> Author: manishamde <[email protected]> Author: Evan Sparks <[email protected]> Closes apache#886 from manishamde/multiclass and squashes the following commits: 26f8acc [Manish Amde] another attempt at fixing mima c5b2d04 [Manish Amde] more MIMA fixes 1ce7212 [Manish Amde] change problem filter for mima 10fdd82 [Manish Amde] fixing MIMA excludes e1c970d [Manish Amde] merged master abf2901 [Manish Amde] adding classes to MimaExcludes.scala 45e767a [Manish Amde] adding developer api annotation for overriden methods c8428c4 [Manish Amde] fixing weird multiline bug afced16 [Manish Amde] removed label weights support 2d85a48 [Manish Amde] minor: fixed scalastyle issues reprise 4e85f2c [Manish Amde] minor: fixed scalastyle issues b2ae41f [Manish Amde] minor: scalastyle e4c1321 [Manish Amde] using while loop for regression histograms d75ac32 [Manish Amde] removed WeightedLabeledPoint from this PR 0fecd38 [Manish Amde] minor: add newline to EOF 2061cf5 [Manish Amde] merged from master 06b1690 [Manish Amde] fixed off-by-one error in bin to split conversion 9cc3e31 [Manish Amde] added implicit conversion import 5c1b2ca [Manish Amde] doc for PointConverter class 485eaae [Manish Amde] implicit conversion from LabeledPoint to WeightedLabeledPoint 3d7f911 [Manish Amde] updated doc 8e44ab8 [Manish Amde] updated doc adc7315 [Manish Amde] support ordered categorical splits for multiclass classification e3e8843 [Manish Amde] minor code formatting 23d4268 [Manish Amde] minor: another minor code style 34ee7b9 [Manish Amde] minor: code style 237762d [Manish Amde] renaming functions 12e6d0a [Manish Amde] minor: removing line in doc 9a90c93 [Manish Amde] Merge branch 'master' into multiclass 1892a2c [Manish Amde] tests and use multiclass binaggregate length when atleast one categorical feature is present f5f6b83 [Manish Amde] multiclass for continous variables 8cfd3b6 [Manish Amde] working for categorical multiclass classification 828ff16 [Manish Amde] added categorical variable test bce835f [Manish Amde] code cleanup 7e5f08c [Manish Amde] minor doc 1dd2735 [Manish Amde] bin search logic for multiclass f16a9bb [Manish Amde] fixing while loop d811425 [Manish Amde] multiclass bin aggregate logic ab5cb21 [Manish Amde] multiclass logic d8e4a11 [Manish Amde] sample weights ed5a2df [Manish Amde] fixed classification requirements d012be7 [Manish Amde] fixed while loop 18d2835 [Manish Amde] changing default values for num classes 6b912dc [Manish Amde] added numclasses to tree runner, predict logic for multiclass, add multiclass option to train 75f2bfc [Manish Amde] minor code style fix e547151 [Manish Amde] minor modifications 34549d0 [Manish Amde] fixing error during merge 098e8c5 [Manish Amde] merged master e006f9d [Manish Amde] changing variable names 5c78e1a [Manish Amde] added multiclass support 6c7af22 [Manish Amde] prepared for multiclass without breaking binary classification 46e06ee [Manish Amde] minor mods 3f85a17 [Manish Amde] tests for multiclass classification 4d5f70c [Manish Amde] added multiclass support for find splits bins 46f909c [Manish Amde] todo for multiclass support 455bea9 [Manish Amde] fixed tests 14aea48 [Manish Amde] changing instance format to weighted labeled point a1a6e09 [Manish Amde] added weighted point class 968ca9d [Manish Amde] merged master 7fc9545 [Manish Amde] added docs ce004a1 [Manish Amde] minor formatting b27ad2c [Manish Amde] formatting 426bb28 [Manish Amde] programming guide blurb 8053fed [Manish Amde] more formatting 5eca9e4 [Manish Amde] grammar 4731cda [Manish Amde] formatting 5e82202 [Manish Amde] added documentation, fixed off by 1 error in max level calculation cbd9f14 [Manish Amde] modified scala.math to math dad9652 [Manish Amde] removed unused imports e0426ee [Manish Amde] renamed parameter 718506b [Manish Amde] added unit test 1517155 [Manish Amde] updated documentation 9dbdabe [Manish Amde] merge from master 719d009 [Manish Amde] updating user documentation fecf89a [manishamde] Merge pull request #6 from etrain/deep_tree 0287772 [Evan Sparks] Fixing scalastyle issue. 2f1e093 [Manish Amde] minor: added doc for maxMemory parameter 2f6072c [manishamde] Merge pull request #5 from etrain/deep_tree abc5a23 [Evan Sparks] Parameterizing max memory. 50b143a [Manish Amde] adding support for very deep trees
1 parent 586e716 commit d88f6be

File tree

12 files changed

+926
-258
lines changed

12 files changed

+926
-258
lines changed

docs/mllib-decision-tree.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,17 @@ bins if the condition is not satisfied.
7777

7878
**Categorical features**
7979

80-
For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for
81-
binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the
80+
For `$M$` categorical feature values, one could come up with `$2^(M-1)-1$` split candidates. For
81+
binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
8282
categorical feature values by the proportion of labels falling in one of the two classes (see
8383
Section 9.2.4 in
8484
[Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
8585
details). For example, for a binary classification problem with one categorical feature with three
8686
categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
8787
features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B
88-
and A , B \| C where \| denotes the split.
88+
and A , B \| C where \| denotes the split. A similar heuristic is used for multiclass classification
89+
when `$2^(M-1)-1$` is greater than the number of bins -- the impurity for each categorical feature value
90+
is used for ordering.
8991

9092
### Stopping rule
9193

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ object DecisionTreeRunner {
4949
case class Params(
5050
input: String = null,
5151
algo: Algo = Classification,
52+
numClassesForClassification: Int = 2,
5253
maxDepth: Int = 5,
5354
impurity: ImpurityType = Gini,
5455
maxBins: Int = 100)
@@ -68,6 +69,10 @@ object DecisionTreeRunner {
6869
opt[Int]("maxDepth")
6970
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
7071
.action((x, c) => c.copy(maxDepth = x))
72+
opt[Int]("numClassesForClassification")
73+
.text(s"number of classes for classification, "
74+
+ s"default: ${defaultParams.numClassesForClassification}")
75+
.action((x, c) => c.copy(numClassesForClassification = x))
7176
opt[Int]("maxBins")
7277
.text(s"max number of bins, default: ${defaultParams.maxBins}")
7378
.action((x, c) => c.copy(maxBins = x))
@@ -118,7 +123,13 @@ object DecisionTreeRunner {
118123
case Variance => impurity.Variance
119124
}
120125

121-
val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
126+
val strategy
127+
= new Strategy(
128+
algo = params.algo,
129+
impurity = impurityCalculator,
130+
maxDepth = params.maxDepth,
131+
maxBins = params.maxBins,
132+
numClassesForClassification = params.numClassesForClassification)
122133
val model = DecisionTree.train(training, strategy)
123134

124135
if (params.algo == Classification) {
@@ -139,12 +150,8 @@ object DecisionTreeRunner {
139150
*/
140151
private def accuracyScore(
141152
model: DecisionTreeModel,
142-
data: RDD[LabeledPoint],
143-
threshold: Double = 0.5): Double = {
144-
def predictedValue(features: Vector): Double = {
145-
if (model.predict(features) < threshold) 0.0 else 1.0
146-
}
147-
val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
153+
data: RDD[LabeledPoint]): Double = {
154+
val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
148155
val count = data.count()
149156
correctCount.toDouble / count
150157
}

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 532 additions & 200 deletions
Large diffs are not rendered by default.

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
2828
* @param algo classification or regression
2929
* @param impurity criterion used for information gain calculation
3030
* @param maxDepth maximum depth of the tree
31+
* @param numClassesForClassification number of classes for classification. Default value is 2
32+
* leads to binary classification
3133
* @param maxBins maximum number of bins used for splitting features
3234
* @param quantileCalculationStrategy algorithm for calculating quantiles
3335
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
@@ -44,7 +46,15 @@ class Strategy (
4446
val algo: Algo,
4547
val impurity: Impurity,
4648
val maxDepth: Int,
49+
val numClassesForClassification: Int = 2,
4750
val maxBins: Int = 100,
4851
val quantileCalculationStrategy: QuantileStrategy = Sort,
4952
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
50-
val maxMemoryInMB: Int = 128) extends Serializable
53+
val maxMemoryInMB: Int = 128) extends Serializable {
54+
55+
require(numClassesForClassification >= 2)
56+
val isMulticlassClassification = numClassesForClassification > 2
57+
val isMulticlassWithCategoricalFeatures
58+
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
59+
60+
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,35 @@ object Entropy extends Impurity {
3131

3232
/**
3333
* :: DeveloperApi ::
34-
* entropy calculation
35-
* @param c0 count of instances with label 0
36-
* @param c1 count of instances with label 1
37-
* @return entropy value
34+
* information calculation for multiclass classification
35+
* @param counts Array[Double] with counts for each label
36+
* @param totalCount sum of counts for all labels
37+
* @return information value
3838
*/
3939
@DeveloperApi
40-
override def calculate(c0: Double, c1: Double): Double = {
41-
if (c0 == 0 || c1 == 0) {
42-
0
43-
} else {
44-
val total = c0 + c1
45-
val f0 = c0 / total
46-
val f1 = c1 / total
47-
-(f0 * log2(f0)) - (f1 * log2(f1))
40+
override def calculate(counts: Array[Double], totalCount: Double): Double = {
41+
val numClasses = counts.length
42+
var impurity = 0.0
43+
var classIndex = 0
44+
while (classIndex < numClasses) {
45+
val classCount = counts(classIndex)
46+
if (classCount != 0) {
47+
val freq = classCount / totalCount
48+
impurity -= freq * log2(freq)
49+
}
50+
classIndex += 1
4851
}
52+
impurity
4953
}
5054

55+
/**
56+
* :: DeveloperApi ::
57+
* variance calculation
58+
* @param count number of instances
59+
* @param sum sum of labels
60+
* @param sumSquares summation of squares of the labels
61+
*/
62+
@DeveloperApi
5163
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
5264
throw new UnsupportedOperationException("Entropy.calculate")
5365
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,32 @@ object Gini extends Impurity {
3030

3131
/**
3232
* :: DeveloperApi ::
33-
* Gini coefficient calculation
34-
* @param c0 count of instances with label 0
35-
* @param c1 count of instances with label 1
36-
* @return Gini coefficient value
33+
* information calculation for multiclass classification
34+
* @param counts Array[Double] with counts for each label
35+
* @param totalCount sum of counts for all labels
36+
* @return information value
3737
*/
3838
@DeveloperApi
39-
override def calculate(c0: Double, c1: Double): Double = {
40-
if (c0 == 0 || c1 == 0) {
41-
0
42-
} else {
43-
val total = c0 + c1
44-
val f0 = c0 / total
45-
val f1 = c1 / total
46-
1 - f0 * f0 - f1 * f1
39+
override def calculate(counts: Array[Double], totalCount: Double): Double = {
40+
val numClasses = counts.length
41+
var impurity = 1.0
42+
var classIndex = 0
43+
while (classIndex < numClasses) {
44+
val freq = counts(classIndex) / totalCount
45+
impurity -= freq * freq
46+
classIndex += 1
4747
}
48+
impurity
4849
}
4950

51+
/**
52+
* :: DeveloperApi ::
53+
* variance calculation
54+
* @param count number of instances
55+
* @param sum sum of labels
56+
* @param sumSquares summation of squares of the labels
57+
*/
58+
@DeveloperApi
5059
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
5160
throw new UnsupportedOperationException("Gini.calculate")
5261
}

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ trait Impurity extends Serializable {
2828

2929
/**
3030
* :: DeveloperApi ::
31-
* information calculation for binary classification
32-
* @param c0 count of instances with label 0
33-
* @param c1 count of instances with label 1
31+
* information calculation for multiclass classification
32+
* @param counts Array[Double] with counts for each label
33+
* @param totalCount sum of counts for all labels
3434
* @return information value
3535
*/
3636
@DeveloperApi
37-
def calculate(c0 : Double, c1 : Double): Double
37+
def calculate(counts: Array[Double], totalCount: Double): Double
3838

3939
/**
4040
* :: DeveloperApi ::

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
2525
*/
2626
@Experimental
2727
object Variance extends Impurity {
28-
override def calculate(c0: Double, c1: Double): Double =
28+
29+
/**
30+
* :: DeveloperApi ::
31+
* information calculation for multiclass classification
32+
* @param counts Array[Double] with counts for each label
33+
* @param totalCount sum of counts for all labels
34+
* @return information value
35+
*/
36+
@DeveloperApi
37+
override def calculate(counts: Array[Double], totalCount: Double): Double =
2938
throw new UnsupportedOperationException("Variance.calculate")
3039

3140
/**

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
2828
* @param highSplit signifying the upper threshold for the continuous feature to be
2929
* accepted in the bin
3030
* @param featureType type of feature -- categorical or continuous
31-
* @param category categorical label value accepted in the bin
31+
* @param category categorical label value accepted in the bin for binary classification
3232
*/
3333
private[tree]
3434
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,19 @@ import org.apache.spark.annotation.DeveloperApi
2727
* @param leftImpurity left node impurity
2828
* @param rightImpurity right node impurity
2929
* @param predict predicted value
30+
* @param prob probability of the label (classification only)
3031
*/
3132
@DeveloperApi
3233
class InformationGainStats(
3334
val gain: Double,
3435
val impurity: Double,
3536
val leftImpurity: Double,
3637
val rightImpurity: Double,
37-
val predict: Double) extends Serializable {
38+
val predict: Double,
39+
val prob: Double = 0.0) extends Serializable {
3840

3941
override def toString = {
40-
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
41-
.format(gain, impurity, leftImpurity, rightImpurity, predict)
42+
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
43+
.format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
4244
}
4345
}

0 commit comments

Comments
 (0)