@@ -133,7 +133,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
133133 maxDepth = 3 ,
134134 numClassesForClassification = 2 ,
135135 maxBins = 100 ,
136- categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ))
136+ categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ))
137137 val (splits, bins) = DecisionTree .findSplitsBins(rdd, strategy)
138138
139139 // Check splits.
@@ -483,7 +483,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
483483 assert(bestSplits(0 )._2.predict === 1 )
484484 }
485485
486- test(" test second level node building with/without groups" ) {
486+ test(" second level node building with/without groups" ) {
487487 val arr = DecisionTreeSuite .generateOrderedLabeledPoints()
488488 assert(arr.length === 1000 )
489489 val rdd = sc.parallelize(arr)
@@ -529,6 +529,33 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
529529
530530 }
531531
532+ test(" stump with continuous variables for multiclass classification" ) {
533+ assert(true == true )
534+ }
535+
536+ test(" stump with categorical variables for multiclass classification" ) {
537+ val arr = DecisionTreeSuite .generateCategoricalDataPointsForMulticlass()
538+ val input = sc.parallelize(arr)
539+ val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 5 ,
540+ numClassesForClassification = 3 , categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ))
541+ assert(strategy.isMulticlassClassification)
542+ val (splits, bins) = DecisionTree .findSplitsBins(input, strategy)
543+ val bestSplits = DecisionTree .findBestSplits(input, new Array (31 ), strategy, 0 ,
544+ Array [List [Filter ]](), splits, bins, 10 )
545+
546+ assert(bestSplits.length === 1 )
547+ val bestSplit = bestSplits(0 )._1
548+ assert(bestSplit.feature === 0 )
549+ assert(bestSplit.categories.length === 1 )
550+ assert(bestSplit.categories.contains(0 ))
551+ assert(bestSplit.featureType === Categorical )
552+ println(bestSplit)
553+ }
554+
555+ test(" stump with continuous + categorical variables for multiclass classification" ) {
556+ assert(true == true )
557+ }
558+
532559}
533560
534561object DecisionTreeSuite {
@@ -576,4 +603,22 @@ object DecisionTreeSuite {
576603 }
577604 arr
578605 }
606+
607+ def generateCategoricalDataPointsForMulticlass (): Array [WeightedLabeledPoint ] = {
608+ val arr = new Array [WeightedLabeledPoint ](3000 )
609+ for (i <- 0 until 3000 ) {
610+ if (i < 1000 ) {
611+ arr(i) = new WeightedLabeledPoint (2.0 , Vectors .dense(2.0 , 2.0 ))
612+ } else if (i < 2000 ) {
613+ arr(i) = new WeightedLabeledPoint (1.0 , Vectors .dense(1.0 , 2.0 ))
614+ } else {
615+ arr(i) = new WeightedLabeledPoint (2.0 , Vectors .dense(2.0 , 2.0 ))
616+ }
617+ }
618+ println(arr(0 ))
619+ println(arr(1000 ))
620+ println(arr(2000 ))
621+ arr
622+ }
623+
579624}
0 commit comments