@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.impurity.Impurity
2828import org .apache .spark .mllib .tree .model ._
2929import org .apache .spark .rdd .RDD
3030import org .apache .spark .util .random .XORShiftRandom
31+ import org .apache .spark .mllib .point .WeightedLabeledPoint
3132
3233/**
3334 * :: Experimental ::
@@ -47,13 +48,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
4748 */
4849 def train (input : RDD [LabeledPoint ]): DecisionTreeModel = {
4950
51+ // Converting from standard instance format to weighted input format for tree training
52+ val weightedInput = input.map(x => WeightedLabeledPoint (x.label,x.features))
53+
5054 // Cache input RDD for speedup during multiple passes.
51- input .cache()
55+ weightedInput .cache()
5256 logDebug(" algo = " + strategy.algo)
5357
5458 // Find the splits and the corresponding bins (interval between the splits) using a sample
5559 // of the input data.
56- val (splits, bins) = DecisionTree .findSplitsBins(input , strategy)
60+ val (splits, bins) = DecisionTree .findSplitsBins(weightedInput , strategy)
5761 val numBins = bins(0 ).length
5862 logDebug(" numBins = " + numBins)
5963
@@ -70,7 +74,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7074 // dummy value for top node (updated during first split calculation)
7175 val nodes = new Array [Node ](maxNumNodes)
7276 // num features
73- val numFeatures = input .take(1 )(0 ).features.size
77+ val numFeatures = weightedInput .take(1 )(0 ).features.size
7478
7579 // Calculate level for single group construction
7680
@@ -109,8 +113,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
109113 logDebug(" #####################################" )
110114
111115 // Find best split for all nodes at a level.
112- val splitsStatsForLevel = DecisionTree .findBestSplits(input , parentImpurities, strategy ,
113- level, filters, splits, bins, maxLevelForSingleGroup)
116+ val splitsStatsForLevel = DecisionTree .findBestSplits(weightedInput , parentImpurities,
117+ strategy, level, filters, splits, bins, maxLevelForSingleGroup)
114118
115119 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
116120 // Extract info for nodes at the current level.
@@ -291,7 +295,7 @@ object DecisionTree extends Serializable with Logging {
291295 * @return array of splits with best splits for all nodes at a given level.
292296 */
293297 protected [tree] def findBestSplits (
294- input : RDD [LabeledPoint ],
298+ input : RDD [WeightedLabeledPoint ],
295299 parentImpurities : Array [Double ],
296300 strategy : Strategy ,
297301 level : Int ,
@@ -339,7 +343,7 @@ object DecisionTree extends Serializable with Logging {
339343 * @return array of splits with best splits for all nodes at a given level.
340344 */
341345 private def findBestSplitsPerGroup (
342- input : RDD [LabeledPoint ],
346+ input : RDD [WeightedLabeledPoint ],
343347 parentImpurities : Array [Double ],
344348 strategy : Strategy ,
345349 level : Int ,
@@ -399,7 +403,7 @@ object DecisionTree extends Serializable with Logging {
399403 * Find whether the sample is valid input for the current node, i.e., whether it passes through
400404 * all the filters for the current node.
401405 */
402- def isSampleValid (parentFilters : List [Filter ], labeledPoint : LabeledPoint ): Boolean = {
406+ def isSampleValid (parentFilters : List [Filter ], labeledPoint : WeightedLabeledPoint ): Boolean = {
403407 // leaf
404408 if ((level > 0 ) & (parentFilters.length == 0 )) {
405409 return false
@@ -438,7 +442,7 @@ object DecisionTree extends Serializable with Logging {
438442 */
439443 def findBin (
440444 featureIndex : Int ,
441- labeledPoint : LabeledPoint ,
445+ labeledPoint : WeightedLabeledPoint ,
442446 isFeatureContinuous : Boolean ): Int = {
443447 val binForFeatures = bins(featureIndex)
444448 val feature = labeledPoint.features(featureIndex)
@@ -509,7 +513,7 @@ object DecisionTree extends Serializable with Logging {
509513 * where b_ij is an integer between 0 and numBins - 1.
510514 * Invalid sample is denoted by noting bin for feature 1 as -1.
511515 */
512- def findBinsForLevel (labeledPoint : LabeledPoint ): Array [Double ] = {
516+ def findBinsForLevel (labeledPoint : WeightedLabeledPoint ): Array [Double ] = {
513517 // Calculate bin index and label per feature per node.
514518 val arr = new Array [Double ](1 + (numFeatures * numNodes))
515519 arr(0 ) = labeledPoint.label
@@ -982,7 +986,7 @@ object DecisionTree extends Serializable with Logging {
982986 * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
983987 */
984988 protected [tree] def findSplitsBins (
985- input : RDD [LabeledPoint ],
989+ input : RDD [WeightedLabeledPoint ],
986990 strategy : Strategy ): (Array [Array [Split ]], Array [Array [Bin ]]) = {
987991 val count = input.count()
988992
0 commit comments