Skip to content

Commit 3f85a17

Browse files
committed
tests for multiclass classification
1 parent 4d5f70c commit 3f85a17

File tree

2 files changed

+143
-25
lines changed

2 files changed

+143
-25
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,8 +1003,6 @@ object DecisionTree extends Serializable with Logging {
10031003
val numBins = if (maxBins <= count) maxBins else count.toInt
10041004
logDebug("numBins = " + numBins)
10051005

1006-
// TODO: Multiclass modification here
1007-
10081006
/*
10091007
* Ensure #bins is always greater than the categories. For multiclass classification,
10101008
* #bins should be greater than 2^(maxCategories - 1) - 1.
@@ -1058,17 +1056,18 @@ object DecisionTree extends Serializable with Logging {
10581056

10591057
// Use different bin/split calculation strategy for multiclass classification
10601058
if (strategy.isMultiClassification) {
1061-
// Iterate from 1 to 2^maxFeatureValue leading to 2^(maxFeatureValue- 1) - 1
1059+
// Iterate from 0 to 2^maxFeatureValue - 1 leading to 2^(maxFeatureValue- 1) - 1
10621060
// combinations.
1063-
var index = 1
1064-
while (index < math.pow(2.0, maxFeatureValue).toInt) {
1065-
val categories: List[Double] = extractMultiClassCategories(index, maxFeatureValue)
1061+
var index = 0
1062+
while (index < math.pow(2.0, maxFeatureValue).toInt - 1) {
1063+
val categories: List[Double]
1064+
= extractMultiClassCategories(index + 1, maxFeatureValue)
10661065
splits(featureIndex)(index)
10671066
= new Split(featureIndex, Double.MinValue, Categorical, categories)
10681067
bins(featureIndex)(index) = {
10691068
if (index == 0) {
10701069
new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
1071-
splits(featureIndex)(0), Categorical, index)
1070+
splits(featureIndex)(0), Categorical, Double.MinValue)
10721071
} else {
10731072
new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical,
10741073
Double.MinValue)
@@ -1147,19 +1146,24 @@ object DecisionTree extends Serializable with Logging {
11471146
}
11481147

11491148
/**
1150-
* Nested method to extract list of eligible categories given an index
1149+
* Nested method to extract list of eligible categories given an index. It extracts the
1150+
* position of ones in a binary representation of the input. If binary
1151+
* representation of a number is 01101 (13), the output list should be (3.0, 2.0,
1152+
* 0.0). The maxFeatureValue gives the number of rightmost digits that will be tested for ones.
11511153
*/
1152-
private def extractMultiClassCategories(i: Int, maxFeatureValue: Double): List[Double] = {
1153-
// TODO: Test this
1154+
private[tree] def extractMultiClassCategories(
1155+
input: Int,
1156+
maxFeatureValue: Int): List[Double] = {
11541157
var categories = List[Double]()
11551158
var j = 0
1159+
var bitShiftedInput = input
11561160
while (j < maxFeatureValue) {
1157-
var copy = i
1158-
if (copy % 2 != 0) {
1161+
if (bitShiftedInput % 2 != 0) {
11591162
// updating the list of categories.
11601163
categories = j.toDouble :: categories
11611164
}
1162-
copy = copy >> 1
1165+
// Right shift by one
1166+
bitShiftedInput = bitShiftedInput >> 1
11631167
j += 1
11641168
}
11651169
categories

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 126 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,120 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
231231
assert(bins(1)(3) === null)
232232
}
233233

234+
test("extract categories from a number for multiclass classification") {
235+
val l = DecisionTree.extractMultiClassCategories(13, 10)
236+
assert(l.length === 3)
237+
assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
238+
}
239+
240+
test("split and bin calculations for categorical variables wiht multiclass classification") {
241+
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
242+
assert(arr.length === 1000)
243+
val rdd = sc.parallelize(arr)
244+
val strategy = new Strategy(
245+
Classification,
246+
Gini,
247+
maxDepth = 3,
248+
maxBins = 100,
249+
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
250+
numClassesForClassification = 3)
251+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
252+
253+
// Expecting 2^2 - 1 = 3 bins/splits
254+
assert(splits(0)(0).feature === 0)
255+
assert(splits(0)(0).threshold === Double.MinValue)
256+
assert(splits(0)(0).featureType === Categorical)
257+
assert(splits(0)(0).categories.length === 1)
258+
assert(splits(0)(0).categories.contains(0.0))
259+
assert(splits(1)(0).feature === 1)
260+
assert(splits(1)(0).threshold === Double.MinValue)
261+
assert(splits(1)(0).featureType === Categorical)
262+
assert(splits(1)(0).categories.length === 1)
263+
assert(splits(1)(0).categories.contains(0.0))
264+
265+
assert(splits(0)(1).feature === 0)
266+
assert(splits(0)(1).threshold === Double.MinValue)
267+
assert(splits(0)(1).featureType === Categorical)
268+
assert(splits(0)(1).categories.length === 1)
269+
assert(splits(0)(1).categories.contains(1.0))
270+
assert(splits(1)(1).feature === 1)
271+
assert(splits(1)(1).threshold === Double.MinValue)
272+
assert(splits(1)(1).featureType === Categorical)
273+
assert(splits(1)(1).categories.length === 1)
274+
assert(splits(1)(1).categories.contains(1.0))
275+
276+
assert(splits(0)(2).feature === 0)
277+
assert(splits(0)(2).threshold === Double.MinValue)
278+
assert(splits(0)(2).featureType === Categorical)
279+
assert(splits(0)(2).categories.length === 2)
280+
assert(splits(0)(2).categories.contains(0.0))
281+
assert(splits(0)(2).categories.contains(1.0))
282+
assert(splits(1)(2).feature === 1)
283+
assert(splits(1)(2).threshold === Double.MinValue)
284+
assert(splits(1)(2).featureType === Categorical)
285+
assert(splits(1)(2).categories.length === 2)
286+
assert(splits(1)(2).categories.contains(0.0))
287+
assert(splits(1)(2).categories.contains(1.0))
288+
289+
assert(splits(0)(3) === null)
290+
291+
292+
// Check bins.
293+
294+
assert(bins(0)(0).category === Double.MinValue)
295+
assert(bins(0)(0).lowSplit.categories.length === 0)
296+
assert(bins(0)(0).highSplit.categories.length === 1)
297+
assert(bins(0)(0).highSplit.categories.contains(0.0))
298+
assert(bins(1)(0).category === Double.MinValue)
299+
assert(bins(1)(0).lowSplit.categories.length === 0)
300+
assert(bins(1)(0).highSplit.categories.length === 1)
301+
assert(bins(1)(0).highSplit.categories.contains(0.0))
302+
303+
assert(bins(0)(1).category === Double.MinValue)
304+
assert(bins(0)(1).lowSplit.categories.length === 1)
305+
assert(bins(0)(1).lowSplit.categories.contains(0.0))
306+
assert(bins(0)(1).highSplit.categories.length === 1)
307+
assert(bins(0)(1).highSplit.categories.contains(1.0))
308+
assert(bins(1)(1).category === Double.MinValue)
309+
assert(bins(1)(1).lowSplit.categories.length === 1)
310+
assert(bins(1)(1).lowSplit.categories.contains(0.0))
311+
assert(bins(1)(1).highSplit.categories.length === 1)
312+
assert(bins(1)(1).highSplit.categories.contains(1.0))
313+
314+
assert(bins(0)(2).category === Double.MinValue)
315+
assert(bins(0)(2).lowSplit.categories.length === 1)
316+
assert(bins(0)(2).lowSplit.categories.contains(1.0))
317+
assert(bins(0)(2).highSplit.categories.length === 2)
318+
assert(bins(0)(2).highSplit.categories.contains(1.0))
319+
assert(bins(0)(2).highSplit.categories.contains(0.0))
320+
assert(bins(1)(2).category === Double.MinValue)
321+
assert(bins(1)(2).lowSplit.categories.length === 1)
322+
assert(bins(1)(2).lowSplit.categories.contains(1.0))
323+
assert(bins(1)(2).highSplit.categories.length === 2)
324+
assert(bins(1)(2).highSplit.categories.contains(1.0))
325+
assert(bins(1)(2).highSplit.categories.contains(0.0))
326+
327+
assert(bins(0)(3) === null)
328+
assert(bins(1)(3) === null)
329+
330+
}
331+
332+
test("split and bin calculations for categorical variables with no sample for one category " +
333+
"for multiclass classification") {
334+
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
335+
assert(arr.length === 1000)
336+
val rdd = sc.parallelize(arr)
337+
val strategy = new Strategy(
338+
Classification,
339+
Gini,
340+
maxDepth = 3,
341+
maxBins = 100,
342+
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3),
343+
numClassesForClassification = 3)
344+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
345+
346+
}
347+
234348
test("classification stump with all categorical variables") {
235349
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
236350
assert(arr.length === 1000)
@@ -430,29 +544,29 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
430544

431545
object DecisionTreeSuite {
432546

433-
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
434-
val arr = new Array[LabeledPoint](1000)
547+
def generateOrderedLabeledPointsWithLabel0(): Array[WeightedLabeledPoint] = {
548+
val arr = new Array[WeightedLabeledPoint](1000)
435549
for (i <- 0 until 1000) {
436-
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
550+
val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
437551
arr(i) = lp
438552
}
439553
arr
440554
}
441555

442-
def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
443-
val arr = new Array[LabeledPoint](1000)
556+
def generateOrderedLabeledPointsWithLabel1(): Array[WeightedLabeledPoint] = {
557+
val arr = new Array[WeightedLabeledPoint](1000)
444558
for (i <- 0 until 1000) {
445-
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
559+
val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
446560
arr(i) = lp
447561
}
448562
arr
449563
}
450564

451-
def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
452-
val arr = new Array[LabeledPoint](1000)
565+
def generateOrderedLabeledPoints(): Array[WeightedLabeledPoint] = {
566+
val arr = new Array[WeightedLabeledPoint](1000)
453567
for (i <- 0 until 1000) {
454568
if (i < 600) {
455-
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
569+
val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
456570
arr(i) = lp
457571
} else {
458572
val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
@@ -462,11 +576,11 @@ object DecisionTreeSuite {
462576
arr
463577
}
464578

465-
def generateCategoricalDataPoints(): Array[LabeledPoint] = {
466-
val arr = new Array[LabeledPoint](1000)
579+
def generateCategoricalDataPoints(): Array[WeightedLabeledPoint] = {
580+
val arr = new Array[WeightedLabeledPoint](1000)
467581
for (i <- 0 until 1000) {
468582
if (i < 600) {
469-
arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
583+
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(0.0, 1.0))
470584
} else {
471585
arr(i) = new WeightedLabeledPoint(0.0, Vectors.dense(1.0, 0.0))
472586
}

0 commit comments

Comments
 (0)