Skip to content

Commit 718506b

Browse files
committed
added unit test
1 parent 1517155 commit 718506b

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ object DecisionTreeRunner {
5151
algo: Algo = Classification,
5252
maxDepth: Int = 5,
5353
impurity: ImpurityType = Gini,
54-
maxBins: Int = 20)
54+
maxBins: Int = 100)
5555

5656
def main(args: Array[String]) {
5757
val defaultParams = Params()

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ import org.apache.spark.SparkContext
2424
import org.apache.spark.mllib.regression.LabeledPoint
2525
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
2626
import org.apache.spark.mllib.tree.model.Filter
27-
import org.apache.spark.mllib.tree.configuration.Strategy
27+
import org.apache.spark.mllib.tree.model.Split
28+
import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
2829
import org.apache.spark.mllib.tree.configuration.Algo._
2930
import org.apache.spark.mllib.tree.configuration.FeatureType._
3031
import org.apache.spark.mllib.linalg.Vectors
@@ -390,6 +391,53 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
390391
assert(bestSplits(0)._2.rightImpurity === 0)
391392
assert(bestSplits(0)._2.predict === 1)
392393
}
394+
395+
test("test second level node building with/without groups") {
396+
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
397+
assert(arr.length === 1000)
398+
val rdd = sc.parallelize(arr)
399+
val strategy = new Strategy(Classification, Entropy, 3, 100)
400+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
401+
assert(splits.length === 2)
402+
assert(splits(0).length === 99)
403+
assert(bins.length === 2)
404+
assert(bins(0).length === 100)
405+
assert(splits(0).length === 99)
406+
assert(bins(0).length === 100)
407+
408+
val leftFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),-1)
409+
val rightFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),1)
410+
val filters = Array[List[Filter]](List(),List(leftFilter),List(rightFilter))
411+
val parentImpurities = Array(0.5, 0.5, 0.5)
412+
413+
// Single group second level tree construction.
414+
val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters,
415+
splits, bins, 10)
416+
assert(bestSplits.length === 2)
417+
assert(bestSplits(0)._2.gain > 0)
418+
assert(bestSplits(1)._2.gain > 0)
419+
420+
// maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
421+
// level tree construction.
422+
val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1,
423+
filters, splits, bins, 0)
424+
assert(bestSplitsWithGroups.length === 2)
425+
assert(bestSplitsWithGroups(0)._2.gain > 0)
426+
assert(bestSplitsWithGroups(1)._2.gain > 0)
427+
428+
// Verify whether the splits obtained using single group and multiple group level
429+
// construction strategies are the same.
430+
for (i <- 0 until bestSplits.length) {
431+
assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1)
432+
assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain)
433+
assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
434+
assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
435+
assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
436+
assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
437+
}
438+
439+
}
440+
393441
}
394442

395443
object DecisionTreeSuite {
@@ -412,6 +460,20 @@ object DecisionTreeSuite {
412460
arr
413461
}
414462

463+
def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
464+
val arr = new Array[LabeledPoint](1000)
465+
for (i <- 0 until 1000){
466+
if (i < 600){
467+
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
468+
arr(i) = lp
469+
} else {
470+
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
471+
arr(i) = lp
472+
}
473+
}
474+
arr
475+
}
476+
415477
def generateCategoricalDataPoints(): Array[LabeledPoint] = {
416478
val arr = new Array[LabeledPoint](1000)
417479
for (i <- 0 until 1000){

0 commit comments

Comments
 (0)