Skip to content

Commit 14aea48

Browse files
committed
changing instance format to weighted labeled point
1 parent a1a6e09 commit 14aea48

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.impurity.Impurity
2828
import org.apache.spark.mllib.tree.model._
2929
import org.apache.spark.rdd.RDD
3030
import org.apache.spark.util.random.XORShiftRandom
31+
import org.apache.spark.mllib.point.WeightedLabeledPoint
3132

3233
/**
3334
* :: Experimental ::
@@ -47,13 +48,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
4748
*/
4849
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
4950

51+
// Converting from standard instance format to weighted input format for tree training
52+
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))
53+
5054
// Cache input RDD for speedup during multiple passes.
51-
input.cache()
55+
weightedInput.cache()
5256
logDebug("algo = " + strategy.algo)
5357

5458
// Find the splits and the corresponding bins (interval between the splits) using a sample
5559
// of the input data.
56-
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
60+
val (splits, bins) = DecisionTree.findSplitsBins(weightedInput, strategy)
5761
val numBins = bins(0).length
5862
logDebug("numBins = " + numBins)
5963

@@ -70,7 +74,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7074
// dummy value for top node (updated during first split calculation)
7175
val nodes = new Array[Node](maxNumNodes)
7276
// num features
73-
val numFeatures = input.take(1)(0).features.size
77+
val numFeatures = weightedInput.take(1)(0).features.size
7478

7579
// Calculate level for single group construction
7680

@@ -109,8 +113,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
109113
logDebug("#####################################")
110114

111115
// Find best split for all nodes at a level.
112-
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
113-
level, filters, splits, bins, maxLevelForSingleGroup)
116+
val splitsStatsForLevel = DecisionTree.findBestSplits(weightedInput, parentImpurities,
117+
strategy, level, filters, splits, bins, maxLevelForSingleGroup)
114118

115119
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
116120
// Extract info for nodes at the current level.
@@ -291,7 +295,7 @@ object DecisionTree extends Serializable with Logging {
291295
* @return array of splits with best splits for all nodes at a given level.
292296
*/
293297
protected[tree] def findBestSplits(
294-
input: RDD[LabeledPoint],
298+
input: RDD[WeightedLabeledPoint],
295299
parentImpurities: Array[Double],
296300
strategy: Strategy,
297301
level: Int,
@@ -339,7 +343,7 @@ object DecisionTree extends Serializable with Logging {
339343
* @return array of splits with best splits for all nodes at a given level.
340344
*/
341345
private def findBestSplitsPerGroup(
342-
input: RDD[LabeledPoint],
346+
input: RDD[WeightedLabeledPoint],
343347
parentImpurities: Array[Double],
344348
strategy: Strategy,
345349
level: Int,
@@ -399,7 +403,7 @@ object DecisionTree extends Serializable with Logging {
399403
* Find whether the sample is valid input for the current node, i.e., whether it passes through
400404
* all the filters for the current node.
401405
*/
402-
def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
406+
def isSampleValid(parentFilters: List[Filter], labeledPoint: WeightedLabeledPoint): Boolean = {
403407
// leaf
404408
if ((level > 0) & (parentFilters.length == 0)) {
405409
return false
@@ -438,7 +442,7 @@ object DecisionTree extends Serializable with Logging {
438442
*/
439443
def findBin(
440444
featureIndex: Int,
441-
labeledPoint: LabeledPoint,
445+
labeledPoint: WeightedLabeledPoint,
442446
isFeatureContinuous: Boolean): Int = {
443447
val binForFeatures = bins(featureIndex)
444448
val feature = labeledPoint.features(featureIndex)
@@ -509,7 +513,7 @@ object DecisionTree extends Serializable with Logging {
509513
* where b_ij is an integer between 0 and numBins - 1.
510514
* Invalid sample is denoted by noting bin for feature 1 as -1.
511515
*/
512-
def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
516+
def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
513517
// Calculate bin index and label per feature per node.
514518
val arr = new Array[Double](1 + (numFeatures * numNodes))
515519
arr(0) = labeledPoint.label
@@ -982,7 +986,7 @@ object DecisionTree extends Serializable with Logging {
982986
* .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
983987
*/
984988
protected[tree] def findSplitsBins(
985-
input: RDD[LabeledPoint],
989+
input: RDD[WeightedLabeledPoint],
986990
strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
987991
val count = input.count()
988992

0 commit comments

Comments (0)