Skip to content

Commit 828ff16

Browse files
committed
added categorical variable test
1 parent bce835f commit 828ff16

File tree

1 file changed

+47
-2
lines changed

1 file changed

+47
-2
lines changed

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 47 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
133133
maxDepth = 3,
134134
numClassesForClassification = 2,
135135
maxBins = 100,
136-
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
136+
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
137137
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
138138

139139
// Check splits.
@@ -483,7 +483,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
483483
assert(bestSplits(0)._2.predict === 1)
484484
}
485485

486-
test("test second level node building with/without groups") {
486+
test("second level node building with/without groups") {
487487
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
488488
assert(arr.length === 1000)
489489
val rdd = sc.parallelize(arr)
@@ -529,6 +529,33 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
529529

530530
}
531531

532+
// Placeholder for the continuous-feature multiclass stump test.
// Marked pending so the runner reports it as not yet implemented instead of
// counting a vacuous `true == true` assertion as a passing test.
test("stump with continuous variables for multiclass classification") {
  pending
}
535+
536+
// Finds the best root split on the categorical multiclass fixture and checks
// which feature and category set were chosen.
test("stump with categorical variables for multiclass classification") {
  val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
  val input = sc.parallelize(arr)
  // Features 0 and 1 are registered in categoricalFeaturesInfo with value 3
  // (presumably the number of categories per feature -- confirm against Strategy).
  val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
    numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
  assert(strategy.isMulticlassClassification)
  val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
  // NOTE(review): 31 looks like 2^maxDepth - 1 for maxDepth = 5, and the trailing
  // 0 / 10 arguments look like the tree level and a level limit -- confirm against
  // the DecisionTree.findBestSplits signature.
  val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
    Array[List[Filter]](), splits, bins, 10)

  // Exactly one split is expected, on feature 0, selecting only category 0.
  assert(bestSplits.length === 1)
  val bestSplit = bestSplits(0)._1
  assert(bestSplit.feature === 0)
  assert(bestSplit.categories.length === 1)
  assert(bestSplit.categories.contains(0))
  assert(bestSplit.featureType === Categorical)
}
554+
555+
// Placeholder for the mixed continuous + categorical multiclass stump test.
// Marked pending so the runner reports it as not yet implemented instead of
// counting a vacuous `true == true` assertion as a passing test.
test("stump with continuous + categorical variables for multiclass classification") {
  pending
}
558+
532559
}
533560

534561
object DecisionTreeSuite {
@@ -576,4 +603,22 @@ object DecisionTreeSuite {
576603
}
577604
arr
578605
}
606+
607+
/**
 * Builds a fixed 3000-point fixture with two features per point.
 *
 * Indices 0-999 and 2000-2999 get label 2.0 with features (2.0, 2.0);
 * indices 1000-1999 get label 1.0 with features (1.0, 2.0).
 *
 * NOTE(review): the first and last ranges are identical, so only labels 1.0 and
 * 2.0 ever appear even though the test using this fixture sets
 * numClassesForClassification = 3 -- confirm this is intended.
 *
 * @return a freshly allocated array of 3000 points, one new instance per index
 */
def generateCategoricalDataPointsForMulticlass(): Array[WeightedLabeledPoint] = {
  Array.tabulate(3000) { i =>
    if (i < 1000) {
      new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
    } else if (i < 2000) {
      new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
    } else {
      new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
    }
  }
}
623+
579624
}

0 commit comments

Comments
 (0)