@@ -24,7 +24,8 @@ import org.apache.spark.SparkContext
2424import org .apache .spark .mllib .regression .LabeledPoint
2525import org .apache .spark .mllib .tree .impurity .{Entropy , Gini , Variance }
2626import org .apache .spark .mllib .tree .model .Filter
27- import org .apache .spark .mllib .tree .configuration .Strategy
27+ import org .apache .spark .mllib .tree .model .Split
28+ import org .apache .spark .mllib .tree .configuration .{FeatureType , Strategy }
2829import org .apache .spark .mllib .tree .configuration .Algo ._
2930import org .apache .spark .mllib .tree .configuration .FeatureType ._
3031import org .apache .spark .mllib .linalg .Vectors
@@ -390,6 +391,53 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
390391 assert(bestSplits(0 )._2.rightImpurity === 0 )
391392 assert(bestSplits(0 )._2.predict === 1 )
392393 }
394+
// Builds the second tree level twice -- once in a single pass and once split
// into groups -- and checks that both runs report the same best split and the
// same information-gain statistics for every node at that level.
test(" test second level node building with/without groups") {
  val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
  assert(arr.length === 1000)
  val rdd = sc.parallelize(arr)
  val strategy = new Strategy(Classification, Entropy, 3, 100)
  val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
  assert(splits.length === 2)
  assert(splits(0).length === 99)
  assert(bins.length === 2)
  assert(bins(0).length === 100)

  // Filters for the two second-level nodes: both use the same split on
  // feature 0 at threshold 400 and differ only in the comparison flag (-1 vs 1).
  // The first entry (empty list) is the unfiltered root.
  val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1)
  val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), 1)
  val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter))
  val parentImpurities = Array(0.5, 0.5, 0.5)

  // Single group second level tree construction.
  val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters,
    splits, bins, 10)
  assert(bestSplits.length === 2)
  assert(bestSplits(0)._2.gain > 0)
  assert(bestSplits(1)._2.gain > 0)

  // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
  // level tree construction.
  val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1,
    filters, splits, bins, 0)
  assert(bestSplitsWithGroups.length === 2)
  assert(bestSplitsWithGroups(0)._2.gain > 0)
  assert(bestSplitsWithGroups(1)._2.gain > 0)

  // Verify whether the splits obtained using single group and multiple group level
  // construction strategies are the same.
  for (i <- bestSplits.indices) {
    val (singleSplit, singleStats) = bestSplits(i)
    val (groupedSplit, groupedStats) = bestSplitsWithGroups(i)
    assert(singleSplit === groupedSplit)
    assert(singleStats.gain === groupedStats.gain)
    assert(singleStats.impurity === groupedStats.impurity)
    assert(singleStats.leftImpurity === groupedStats.leftImpurity)
    assert(singleStats.rightImpurity === groupedStats.rightImpurity)
    assert(singleStats.predict === groupedStats.predict)
  }
}
440+
393441}
394442
395443object DecisionTreeSuite {
@@ -412,6 +460,20 @@ object DecisionTreeSuite {
412460 arr
413461 }
414462
/**
 * Builds 1000 labeled points ordered by index `i` (0 until 1000).
 *
 * Point `i` has the two dense features `(i, 1000 - i)`. The first 600 points
 * (i < 600) carry label 0.0 and the remaining 400 carry label 1.0.
 *
 * @return an array of 1000 points in ascending order of `i`
 */
def generateOrderedLabeledPoints(): Array[LabeledPoint] =
  Array.tabulate(1000) { idx =>
    // Only the label depends on which side of the 600 boundary the index falls.
    val label = if (idx < 600) 0.0 else 1.0
    new LabeledPoint(label, Vectors.dense(idx.toDouble, 1000.0 - idx))
  }
476+
415477 def generateCategoricalDataPoints (): Array [LabeledPoint ] = {
416478 val arr = new Array [LabeledPoint ](1000 )
417479 for (i <- 0 until 1000 ){
0 commit comments