@@ -231,6 +231,120 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
231231 assert(bins(1 )(3 ) === null )
232232 }
233233
234+ test(" extract categories from a number for multiclass classification" ) {
235+ val l = DecisionTree .extractMultiClassCategories(13 , 10 )
236+ assert(l.length === 3 )
237+ assert(List (3.0 , 2.0 , 0.0 ).toSeq == l.toSeq)
238+ }
239+
240+ test(" split and bin calculations for categorical variables with multiclass classification" ) {
241+ val arr = DecisionTreeSuite .generateCategoricalDataPoints()
242+ assert(arr.length === 1000 )
243+ val rdd = sc.parallelize(arr)
244+ val strategy = new Strategy (
245+ Classification ,
246+ Gini ,
247+ maxDepth = 3 ,
248+ maxBins = 100 ,
249+ categoricalFeaturesInfo = Map (0 -> 2 , 1 -> 2 ),
250+ numClassesForClassification = 3 )
251+ val (splits, bins) = DecisionTree .findSplitsBins(rdd, strategy)
252+
253+ // Expecting 2^2 - 1 = 3 bins/splits
254+ assert(splits(0 )(0 ).feature === 0 )
255+ assert(splits(0 )(0 ).threshold === Double .MinValue )
256+ assert(splits(0 )(0 ).featureType === Categorical )
257+ assert(splits(0 )(0 ).categories.length === 1 )
258+ assert(splits(0 )(0 ).categories.contains(0.0 ))
259+ assert(splits(1 )(0 ).feature === 1 )
260+ assert(splits(1 )(0 ).threshold === Double .MinValue )
261+ assert(splits(1 )(0 ).featureType === Categorical )
262+ assert(splits(1 )(0 ).categories.length === 1 )
263+ assert(splits(1 )(0 ).categories.contains(0.0 ))
264+
265+ assert(splits(0 )(1 ).feature === 0 )
266+ assert(splits(0 )(1 ).threshold === Double .MinValue )
267+ assert(splits(0 )(1 ).featureType === Categorical )
268+ assert(splits(0 )(1 ).categories.length === 1 )
269+ assert(splits(0 )(1 ).categories.contains(1.0 ))
270+ assert(splits(1 )(1 ).feature === 1 )
271+ assert(splits(1 )(1 ).threshold === Double .MinValue )
272+ assert(splits(1 )(1 ).featureType === Categorical )
273+ assert(splits(1 )(1 ).categories.length === 1 )
274+ assert(splits(1 )(1 ).categories.contains(1.0 ))
275+
276+ assert(splits(0 )(2 ).feature === 0 )
277+ assert(splits(0 )(2 ).threshold === Double .MinValue )
278+ assert(splits(0 )(2 ).featureType === Categorical )
279+ assert(splits(0 )(2 ).categories.length === 2 )
280+ assert(splits(0 )(2 ).categories.contains(0.0 ))
281+ assert(splits(0 )(2 ).categories.contains(1.0 ))
282+ assert(splits(1 )(2 ).feature === 1 )
283+ assert(splits(1 )(2 ).threshold === Double .MinValue )
284+ assert(splits(1 )(2 ).featureType === Categorical )
285+ assert(splits(1 )(2 ).categories.length === 2 )
286+ assert(splits(1 )(2 ).categories.contains(0.0 ))
287+ assert(splits(1 )(2 ).categories.contains(1.0 ))
288+
289+ assert(splits(0 )(3 ) === null )
290+
291+
292+ // Check bins.
293+
294+ assert(bins(0 )(0 ).category === Double .MinValue )
295+ assert(bins(0 )(0 ).lowSplit.categories.length === 0 )
296+ assert(bins(0 )(0 ).highSplit.categories.length === 1 )
297+ assert(bins(0 )(0 ).highSplit.categories.contains(0.0 ))
298+ assert(bins(1 )(0 ).category === Double .MinValue )
299+ assert(bins(1 )(0 ).lowSplit.categories.length === 0 )
300+ assert(bins(1 )(0 ).highSplit.categories.length === 1 )
301+ assert(bins(1 )(0 ).highSplit.categories.contains(0.0 ))
302+
303+ assert(bins(0 )(1 ).category === Double .MinValue )
304+ assert(bins(0 )(1 ).lowSplit.categories.length === 1 )
305+ assert(bins(0 )(1 ).lowSplit.categories.contains(0.0 ))
306+ assert(bins(0 )(1 ).highSplit.categories.length === 1 )
307+ assert(bins(0 )(1 ).highSplit.categories.contains(1.0 ))
308+ assert(bins(1 )(1 ).category === Double .MinValue )
309+ assert(bins(1 )(1 ).lowSplit.categories.length === 1 )
310+ assert(bins(1 )(1 ).lowSplit.categories.contains(0.0 ))
311+ assert(bins(1 )(1 ).highSplit.categories.length === 1 )
312+ assert(bins(1 )(1 ).highSplit.categories.contains(1.0 ))
313+
314+ assert(bins(0 )(2 ).category === Double .MinValue )
315+ assert(bins(0 )(2 ).lowSplit.categories.length === 1 )
316+ assert(bins(0 )(2 ).lowSplit.categories.contains(1.0 ))
317+ assert(bins(0 )(2 ).highSplit.categories.length === 2 )
318+ assert(bins(0 )(2 ).highSplit.categories.contains(1.0 ))
319+ assert(bins(0 )(2 ).highSplit.categories.contains(0.0 ))
320+ assert(bins(1 )(2 ).category === Double .MinValue )
321+ assert(bins(1 )(2 ).lowSplit.categories.length === 1 )
322+ assert(bins(1 )(2 ).lowSplit.categories.contains(1.0 ))
323+ assert(bins(1 )(2 ).highSplit.categories.length === 2 )
324+ assert(bins(1 )(2 ).highSplit.categories.contains(1.0 ))
325+ assert(bins(1 )(2 ).highSplit.categories.contains(0.0 ))
326+
327+ assert(bins(0 )(3 ) === null )
328+ assert(bins(1 )(3 ) === null )
329+
330+ }
331+
332+ test(" split and bin calculations for categorical variables with no sample for one category " +
333+ " for multiclass classification" ) {
334+ val arr = DecisionTreeSuite .generateCategoricalDataPoints()
335+ assert(arr.length === 1000 )
336+ val rdd = sc.parallelize(arr)
337+ val strategy = new Strategy (
338+ Classification ,
339+ Gini ,
340+ maxDepth = 3 ,
341+ maxBins = 100 ,
342+ categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ),
343+ numClassesForClassification = 3 )
344+ val (splits, bins) = DecisionTree .findSplitsBins(rdd, strategy)
345+
346+ }
347+
234348 test(" classification stump with all categorical variables" ) {
235349 val arr = DecisionTreeSuite .generateCategoricalDataPoints()
236350 assert(arr.length === 1000 )
@@ -430,29 +544,29 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
430544
431545object DecisionTreeSuite {
432546
433- def generateOrderedLabeledPointsWithLabel0 (): Array [LabeledPoint ] = {
434- val arr = new Array [LabeledPoint ](1000 )
547+ def generateOrderedLabeledPointsWithLabel0 (): Array [WeightedLabeledPoint ] = {
548+ val arr = new Array [WeightedLabeledPoint ](1000 )
435549 for (i <- 0 until 1000 ) {
436- val lp = new LabeledPoint (0.0 , Vectors .dense(i.toDouble, 1000.0 - i))
550+ val lp = new WeightedLabeledPoint (0.0 , Vectors .dense(i.toDouble, 1000.0 - i))
437551 arr(i) = lp
438552 }
439553 arr
440554 }
441555
442- def generateOrderedLabeledPointsWithLabel1 (): Array [LabeledPoint ] = {
443- val arr = new Array [LabeledPoint ](1000 )
556+ def generateOrderedLabeledPointsWithLabel1 (): Array [WeightedLabeledPoint ] = {
557+ val arr = new Array [WeightedLabeledPoint ](1000 )
444558 for (i <- 0 until 1000 ) {
445- val lp = new LabeledPoint (1.0 , Vectors .dense(i.toDouble, 999.0 - i))
559+ val lp = new WeightedLabeledPoint (1.0 , Vectors .dense(i.toDouble, 999.0 - i))
446560 arr(i) = lp
447561 }
448562 arr
449563 }
450564
451- def generateOrderedLabeledPoints (): Array [LabeledPoint ] = {
452- val arr = new Array [LabeledPoint ](1000 )
565+ def generateOrderedLabeledPoints (): Array [WeightedLabeledPoint ] = {
566+ val arr = new Array [WeightedLabeledPoint ](1000 )
453567 for (i <- 0 until 1000 ) {
454568 if (i < 600 ) {
455- val lp = new LabeledPoint (0.0 , Vectors .dense(i.toDouble, 1000.0 - i))
569+ val lp = new WeightedLabeledPoint (0.0 , Vectors .dense(i.toDouble, 1000.0 - i))
456570 arr(i) = lp
457571 } else {
458572 val lp = new WeightedLabeledPoint (1.0 , Vectors .dense(i.toDouble, 1000.0 - i))
@@ -462,11 +576,11 @@ object DecisionTreeSuite {
462576 arr
463577 }
464578
465- def generateCategoricalDataPoints (): Array [LabeledPoint ] = {
466- val arr = new Array [LabeledPoint ](1000 )
579+ def generateCategoricalDataPoints (): Array [WeightedLabeledPoint ] = {
580+ val arr = new Array [WeightedLabeledPoint ](1000 )
467581 for (i <- 0 until 1000 ) {
468582 if (i < 600 ) {
469- arr(i) = new LabeledPoint (1.0 , Vectors .dense(0.0 , 1.0 ))
583+ arr(i) = new WeightedLabeledPoint (1.0 , Vectors .dense(0.0 , 1.0 ))
470584 } else {
471585 arr(i) = new WeightedLabeledPoint (0.0 , Vectors .dense(1.0 , 0.0 ))
472586 }
0 commit comments