Skip to content

Commit d75ac32

Browse files
committed
removed WeightedLabeledPoint from this PR
1 parent 0fecd38 commit d75ac32

File tree

4 files changed

+36
-105
lines changed

4 files changed

+36
-105
lines changed

mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala

Lines changed: 0 additions & 35 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala

Lines changed: 0 additions & 32 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20-
import org.apache.spark.mllib.point.PointConverter._
2120
import org.apache.spark.annotation.Experimental
2221
import org.apache.spark.Logging
2322
import org.apache.spark.mllib.regression.LabeledPoint
@@ -29,7 +28,6 @@ import org.apache.spark.mllib.tree.impurity.Impurity
2928
import org.apache.spark.mllib.tree.model._
3029
import org.apache.spark.rdd.RDD
3130
import org.apache.spark.util.random.XORShiftRandom
32-
import org.apache.spark.mllib.point.WeightedLabeledPoint
3331

3432
/**
3533
* :: Experimental ::
@@ -47,7 +45,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
4745
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
4846
* @return a DecisionTreeModel that can be used for prediction
4947
*/
50-
def train(input: RDD[WeightedLabeledPoint]): DecisionTreeModel = {
48+
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
5149

5250
// Cache input RDD for speedup during multiple passes.
5351
input.cache()
@@ -352,7 +350,7 @@ object DecisionTree extends Serializable with Logging {
352350
* @return array of splits with best splits for all nodes at a given level.
353351
*/
354352
protected[tree] def findBestSplits(
355-
input: RDD[WeightedLabeledPoint],
353+
input: RDD[LabeledPoint],
356354
parentImpurities: Array[Double],
357355
strategy: Strategy,
358356
level: Int,
@@ -400,7 +398,7 @@ object DecisionTree extends Serializable with Logging {
400398
* @return array of splits with best splits for all nodes at a given level.
401399
*/
402400
private def findBestSplitsPerGroup(
403-
input: RDD[WeightedLabeledPoint],
401+
input: RDD[LabeledPoint],
404402
parentImpurities: Array[Double],
405403
strategy: Strategy,
406404
level: Int,
@@ -469,7 +467,7 @@ object DecisionTree extends Serializable with Logging {
469467
* Find whether the sample is valid input for the current node, i.e., whether it passes through
470468
* all the filters for the current node.
471469
*/
472-
def isSampleValid(parentFilters: List[Filter], labeledPoint: WeightedLabeledPoint): Boolean = {
470+
def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
473471
// leaf
474472
if ((level > 0) && (parentFilters.length == 0)) {
475473
return false
@@ -506,7 +504,7 @@ object DecisionTree extends Serializable with Logging {
506504
/**
507505
* Find bin for one feature.
508506
*/
509-
def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint,
507+
def findBin(featureIndex: Int, labeledPoint: LabeledPoint,
510508
isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
511509
val binForFeatures = bins(featureIndex)
512510
val feature = labeledPoint.features(featureIndex)
@@ -595,7 +593,7 @@ object DecisionTree extends Serializable with Logging {
595593
* classification and the categorical feature value in multiclass classification.
596594
* Invalid sample is denoted by noting bin for feature 1 as -1.
597595
*/
598-
def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
596+
def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
599597
// Calculate bin index and label per feature per node.
600598
val arr = new Array[Double](1 + (numFeatures * numNodes))
601599
// First element of the array is the label of the instance.
@@ -1283,7 +1281,7 @@ object DecisionTree extends Serializable with Logging {
12831281
* [[org.apache.spark.mllib.tree.model.Bin]] of size (numFeatures, numSplits1)
12841282
*/
12851283
protected[tree] def findSplitsBins(
1286-
input: RDD[WeightedLabeledPoint],
1284+
input: RDD[LabeledPoint],
12871285
strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
12881286
val count = input.count()
12891287

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 29 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree
1919

2020
import org.scalatest.FunSuite
2121

22-
import org.apache.spark.mllib.point.WeightedLabeledPoint
2322
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
2423
import org.apache.spark.mllib.tree.model.Filter
2524
import org.apache.spark.mllib.tree.model.Split
@@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
2827
import org.apache.spark.mllib.tree.configuration.FeatureType._
2928
import org.apache.spark.mllib.linalg.Vectors
3029
import org.apache.spark.mllib.util.LocalSparkContext
30+
import org.apache.spark.mllib.regression.LabeledPoint
3131

3232
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
3333

@@ -664,86 +664,86 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
664664

665665
object DecisionTreeSuite {
666666

667-
def generateOrderedLabeledPointsWithLabel0(): Array[WeightedLabeledPoint] = {
668-
val arr = new Array[WeightedLabeledPoint](1000)
667+
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
668+
val arr = new Array[LabeledPoint](1000)
669669
for (i <- 0 until 1000) {
670-
val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
670+
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
671671
arr(i) = lp
672672
}
673673
arr
674674
}
675675

676-
def generateOrderedLabeledPointsWithLabel1(): Array[WeightedLabeledPoint] = {
677-
val arr = new Array[WeightedLabeledPoint](1000)
676+
def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
677+
val arr = new Array[LabeledPoint](1000)
678678
for (i <- 0 until 1000) {
679-
val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
679+
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
680680
arr(i) = lp
681681
}
682682
arr
683683
}
684684

685-
def generateOrderedLabeledPoints(): Array[WeightedLabeledPoint] = {
686-
val arr = new Array[WeightedLabeledPoint](1000)
685+
def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
686+
val arr = new Array[LabeledPoint](1000)
687687
for (i <- 0 until 1000) {
688688
if (i < 600) {
689-
val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
689+
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
690690
arr(i) = lp
691691
} else {
692-
val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
692+
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
693693
arr(i) = lp
694694
}
695695
}
696696
arr
697697
}
698698

699-
def generateCategoricalDataPoints(): Array[WeightedLabeledPoint] = {
700-
val arr = new Array[WeightedLabeledPoint](1000)
699+
def generateCategoricalDataPoints(): Array[LabeledPoint] = {
700+
val arr = new Array[LabeledPoint](1000)
701701
for (i <- 0 until 1000) {
702702
if (i < 600) {
703-
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(0.0, 1.0))
703+
arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
704704
} else {
705-
arr(i) = new WeightedLabeledPoint(0.0, Vectors.dense(1.0, 0.0))
705+
arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))
706706
}
707707
}
708708
arr
709709
}
710710

711-
def generateCategoricalDataPointsForMulticlass(): Array[WeightedLabeledPoint] = {
712-
val arr = new Array[WeightedLabeledPoint](3000)
711+
def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
712+
val arr = new Array[LabeledPoint](3000)
713713
for (i <- 0 until 3000) {
714714
if (i < 1000) {
715-
arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
715+
arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
716716
} else if (i < 2000) {
717-
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
717+
arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
718718
} else {
719-
arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
719+
arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
720720
}
721721
}
722722
arr
723723
}
724724

725-
def generateContinuousDataPointsForMulticlass(): Array[WeightedLabeledPoint] = {
726-
val arr = new Array[WeightedLabeledPoint](3000)
725+
def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = {
726+
val arr = new Array[LabeledPoint](3000)
727727
for (i <- 0 until 3000) {
728728
if (i < 2000) {
729-
arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, i))
729+
arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, i))
730730
} else {
731-
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, i))
731+
arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, i))
732732
}
733733
}
734734
arr
735735
}
736736

737737
def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
738-
Array[WeightedLabeledPoint] = {
739-
val arr = new Array[WeightedLabeledPoint](3000)
738+
Array[LabeledPoint] = {
739+
val arr = new Array[LabeledPoint](3000)
740740
for (i <- 0 until 3000) {
741741
if (i < 1000) {
742-
arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
742+
arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
743743
} else if (i < 2000) {
744-
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
744+
arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
745745
} else {
746-
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 2.0))
746+
arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0))
747747
}
748748
}
749749
arr

0 commit comments

Comments (0)