@@ -46,18 +46,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
4646 * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
4747 * @return a DecisionTreeModel that can be used for prediction
4848 */
49- def train (input : RDD [LabeledPoint ]): DecisionTreeModel = {
50-
51- // Converting from standard instance format to weighted input format for tree training
52- val weightedInput = input.map(x => WeightedLabeledPoint (x.label,x.features))
49+ def train (input : RDD [WeightedLabeledPoint ]): DecisionTreeModel = {
5350
5451 // Cache input RDD for speedup during multiple passes.
55- weightedInput .cache()
52+ input .cache()
5653 logDebug(" algo = " + strategy.algo)
5754
5855 // Find the splits and the corresponding bins (interval between the splits) using a sample
5956 // of the input data.
60- val (splits, bins) = DecisionTree .findSplitsBins(weightedInput , strategy)
57+ val (splits, bins) = DecisionTree .findSplitsBins(input , strategy)
6158 val numBins = bins(0 ).length
6259 logDebug(" numBins = " + numBins)
6360
@@ -74,7 +71,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7471 // dummy value for top node (updated during first split calculation)
7572 val nodes = new Array [Node ](maxNumNodes)
7673 // num features
77- val numFeatures = weightedInput .take(1 )(0 ).features.size
74+ val numFeatures = input .take(1 )(0 ).features.size
7875
7976 // Calculate level for single group construction
8077
@@ -113,7 +110,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
113110 logDebug(" #####################################" )
114111
115112 // Find best split for all nodes at a level.
116- val splitsStatsForLevel = DecisionTree .findBestSplits(weightedInput , parentImpurities,
113+ val splitsStatsForLevel = DecisionTree .findBestSplits(input , parentImpurities,
117114 strategy, level, filters, splits, bins, maxLevelForSingleGroup)
118115
119116 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
@@ -216,7 +213,9 @@ object DecisionTree extends Serializable with Logging {
216213 * @return a DecisionTreeModel that can be used for prediction
217214 */
218215 def train (input : RDD [LabeledPoint ], strategy : Strategy ): DecisionTreeModel = {
219- new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
216+ // Converting from standard instance format to weighted input format for tree training
217+ val weightedInput = input.map(x => WeightedLabeledPoint (x.label,x.features))
218+ new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
220219 }
221220
222221 /**
@@ -238,7 +237,9 @@ object DecisionTree extends Serializable with Logging {
238237 impurity : Impurity ,
239238 maxDepth : Int ): DecisionTreeModel = {
240239 val strategy = new Strategy (algo,impurity,maxDepth)
241- new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
240+ // Converting from standard instance format to weighted input format for tree training
241+ val weightedInput = input.map(x => WeightedLabeledPoint (x.label,x.features))
242+ new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
242243 }
243244
244245
@@ -273,7 +274,9 @@ object DecisionTree extends Serializable with Logging {
273274 categoricalFeaturesInfo : Map [Int ,Int ]): DecisionTreeModel = {
274275 val strategy = new Strategy (algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
275276 categoricalFeaturesInfo)
276- new DecisionTree (strategy).train(input : RDD [LabeledPoint ])
277+ // Converting from standard instance format to weighted input format for tree training
278+ val weightedInput = input.map(x => WeightedLabeledPoint (x.label,x.features))
279+ new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
277280 }
278281
279282 private val InvalidBinIndex = - 1
0 commit comments