@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree
1919
2020import org .scalatest .FunSuite
2121
22- import org .apache .spark .mllib .point .WeightedLabeledPoint
2322import org .apache .spark .mllib .tree .impurity .{Entropy , Gini , Variance }
2423import org .apache .spark .mllib .tree .model .Filter
2524import org .apache .spark .mllib .tree .model .Split
@@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
2827import org .apache .spark .mllib .tree .configuration .FeatureType ._
2928import org .apache .spark .mllib .linalg .Vectors
3029import org .apache .spark .mllib .util .LocalSparkContext
30+ import org .apache .spark .mllib .regression .LabeledPoint
3131
3232class DecisionTreeSuite extends FunSuite with LocalSparkContext {
3333
@@ -664,86 +664,86 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
664664
665665object DecisionTreeSuite {
666666
667- def generateOrderedLabeledPointsWithLabel0 (): Array [WeightedLabeledPoint ] = {
668- val arr = new Array [WeightedLabeledPoint ](1000 )
667+ def generateOrderedLabeledPointsWithLabel0 (): Array [LabeledPoint ] = {
668+ val arr = new Array [LabeledPoint ](1000 )
669669 for (i <- 0 until 1000 ) {
670- val lp = new WeightedLabeledPoint (0.0 , Vectors .dense(i.toDouble, 1000.0 - i))
670+ val lp = new LabeledPoint (0.0 , Vectors .dense(i.toDouble, 1000.0 - i))
671671 arr(i) = lp
672672 }
673673 arr
674674 }
675675
676- def generateOrderedLabeledPointsWithLabel1 (): Array [WeightedLabeledPoint ] = {
677- val arr = new Array [WeightedLabeledPoint ](1000 )
676+ def generateOrderedLabeledPointsWithLabel1 (): Array [LabeledPoint ] = {
677+ val arr = new Array [LabeledPoint ](1000 )
678678 for (i <- 0 until 1000 ) {
679- val lp = new WeightedLabeledPoint (1.0 , Vectors .dense(i.toDouble, 999.0 - i))
679+ val lp = new LabeledPoint (1.0 , Vectors .dense(i.toDouble, 999.0 - i))
680680 arr(i) = lp
681681 }
682682 arr
683683 }
684684
685- def generateOrderedLabeledPoints (): Array [WeightedLabeledPoint ] = {
686- val arr = new Array [WeightedLabeledPoint ](1000 )
685+ def generateOrderedLabeledPoints (): Array [LabeledPoint ] = {
686+ val arr = new Array [LabeledPoint ](1000 )
687687 for (i <- 0 until 1000 ) {
688688 if (i < 600 ) {
689- val lp = new WeightedLabeledPoint (0.0 , Vectors .dense(i.toDouble, 1000.0 - i))
689+ val lp = new LabeledPoint (0.0 , Vectors .dense(i.toDouble, 1000.0 - i))
690690 arr(i) = lp
691691 } else {
692- val lp = new WeightedLabeledPoint (1.0 , Vectors .dense(i.toDouble, 1000.0 - i))
692+ val lp = new LabeledPoint (1.0 , Vectors .dense(i.toDouble, 1000.0 - i))
693693 arr(i) = lp
694694 }
695695 }
696696 arr
697697 }
698698
699- def generateCategoricalDataPoints (): Array [WeightedLabeledPoint ] = {
700- val arr = new Array [WeightedLabeledPoint ](1000 )
699+ def generateCategoricalDataPoints (): Array [LabeledPoint ] = {
700+ val arr = new Array [LabeledPoint ](1000 )
701701 for (i <- 0 until 1000 ) {
702702 if (i < 600 ) {
703- arr(i) = new WeightedLabeledPoint (1.0 , Vectors .dense(0.0 , 1.0 ))
703+ arr(i) = new LabeledPoint (1.0 , Vectors .dense(0.0 , 1.0 ))
704704 } else {
705- arr(i) = new WeightedLabeledPoint (0.0 , Vectors .dense(1.0 , 0.0 ))
705+ arr(i) = new LabeledPoint (0.0 , Vectors .dense(1.0 , 0.0 ))
706706 }
707707 }
708708 arr
709709 }
710710
711- def generateCategoricalDataPointsForMulticlass (): Array [WeightedLabeledPoint ] = {
712- val arr = new Array [WeightedLabeledPoint ](3000 )
711+ def generateCategoricalDataPointsForMulticlass (): Array [LabeledPoint ] = {
712+ val arr = new Array [LabeledPoint ](3000 )
713713 for (i <- 0 until 3000 ) {
714714 if (i < 1000 ) {
715- arr(i) = new WeightedLabeledPoint (2.0 , Vectors .dense(2.0 , 2.0 ))
715+ arr(i) = new LabeledPoint (2.0 , Vectors .dense(2.0 , 2.0 ))
716716 } else if (i < 2000 ) {
717- arr(i) = new WeightedLabeledPoint (1.0 , Vectors .dense(1.0 , 2.0 ))
717+ arr(i) = new LabeledPoint (1.0 , Vectors .dense(1.0 , 2.0 ))
718718 } else {
719- arr(i) = new WeightedLabeledPoint (2.0 , Vectors .dense(2.0 , 2.0 ))
719+ arr(i) = new LabeledPoint (2.0 , Vectors .dense(2.0 , 2.0 ))
720720 }
721721 }
722722 arr
723723 }
724724
725- def generateContinuousDataPointsForMulticlass (): Array [WeightedLabeledPoint ] = {
726- val arr = new Array [WeightedLabeledPoint ](3000 )
725+ def generateContinuousDataPointsForMulticlass (): Array [LabeledPoint ] = {
726+ val arr = new Array [LabeledPoint ](3000 )
727727 for (i <- 0 until 3000 ) {
728728 if (i < 2000 ) {
729- arr(i) = new WeightedLabeledPoint (2.0 , Vectors .dense(2.0 , i))
729+ arr(i) = new LabeledPoint (2.0 , Vectors .dense(2.0 , i))
730730 } else {
731- arr(i) = new WeightedLabeledPoint (1.0 , Vectors .dense(2.0 , i))
731+ arr(i) = new LabeledPoint (1.0 , Vectors .dense(2.0 , i))
732732 }
733733 }
734734 arr
735735 }
736736
737737 def generateCategoricalDataPointsForMulticlassForOrderedFeatures ():
738- Array [WeightedLabeledPoint ] = {
739- val arr = new Array [WeightedLabeledPoint ](3000 )
738+ Array [LabeledPoint ] = {
739+ val arr = new Array [LabeledPoint ](3000 )
740740 for (i <- 0 until 3000 ) {
741741 if (i < 1000 ) {
742- arr(i) = new WeightedLabeledPoint (2.0 , Vectors .dense(2.0 , 2.0 ))
742+ arr(i) = new LabeledPoint (2.0 , Vectors .dense(2.0 , 2.0 ))
743743 } else if (i < 2000 ) {
744- arr(i) = new WeightedLabeledPoint (1.0 , Vectors .dense(1.0 , 2.0 ))
744+ arr(i) = new LabeledPoint (1.0 , Vectors .dense(1.0 , 2.0 ))
745745 } else {
746- arr(i) = new WeightedLabeledPoint (1.0 , Vectors .dense(2.0 , 2.0 ))
746+ arr(i) = new LabeledPoint (1.0 , Vectors .dense(2.0 , 2.0 ))
747747 }
748748 }
749749 arr
0 commit comments