@@ -24,7 +24,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{DTMetadata, TreePoint}
+import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
 import org.apache.spark.mllib.linalg.Vectors
@@ -64,7 +64,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(bins.length === 2)
@@ -83,7 +83,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(bins.length === 2)
@@ -164,7 +164,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)

     // Check splits.
@@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)

     // Expecting 2^2 - 1 = 3 bins/splits
@@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)

     // 2^10 - 1 > 100, so categorical variables will be ordered
@@ -433,7 +433,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
@@ -462,7 +462,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       maxDepth = 2,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
     val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
@@ -502,7 +502,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -526,7 +526,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Gini, 3, 2, 100)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -551,7 +551,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -576,7 +576,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -601,7 +601,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     assert(splits.length === 2)
     assert(splits(0).length === 99)
@@ -653,7 +653,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(strategy.isMulticlassClassification)
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
@@ -710,7 +710,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 3, maxBins = maxBins,
       categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
     assert(strategy.isMulticlassClassification)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)

     val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 1.0)
@@ -739,7 +739,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3)
     assert(strategy.isMulticlassClassification)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)

     val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 0.9)
@@ -765,7 +765,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
     assert(strategy.isMulticlassClassification)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)

     val model = DecisionTree.train(rdd, strategy)
     validateClassifier(model, arr, 0.9)
@@ -790,7 +790,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
       numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
     assert(strategy.isMulticlassClassification)
-    val metadata = DTMetadata.buildMetadata(rdd, strategy)
+    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)

     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
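
For reference, the end-to-end call pattern these tests exercise after the rename looks roughly like the sketch below. It is stitched together from the snippets in this diff; the only behavioral change in the diff is that DTMetadata.buildMetadata is now DecisionTreeMetadata.buildMetadata. The toy LabeledPoint data is invented for illustration, a SparkContext named sc is assumed to be in scope (as LocalSparkContext provides in the suite), and the org.apache.spark.mllib.tree.impl types and findSplitsBins are internal helpers, so this is only reachable from test or package-internal code.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.Gini

// Invented two-feature binary-label data, standing in for the suite's generated arrays.
val arr = Array.tabulate(1000) { i =>
  LabeledPoint(i % 2, Vectors.dense(i.toDouble, (i % 10).toDouble))
}
val rdd = sc.parallelize(arr)

// Positional arguments as used in the suite: algo, impurity, maxDepth,
// numClassesForClassification, maxBins.
val strategy = new Strategy(Classification, Gini, 3, 2, 100)

// Renamed entry point (formerly DTMetadata.buildMetadata).
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)

// The downstream steps keep the same shape as before the rename.
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
val model = DecisionTree.train(rdd, strategy)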