Skip to content

Commit e006f9d

Browse files
committed
changing variable names
1 parent 5c78e1a commit e006f9d

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
4646
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
4747
* @return a DecisionTreeModel that can be used for prediction
4848
*/
49-
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
50-
51-
// Converting from standard instance format to weighted input format for tree training
52-
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))
49+
def train(input: RDD[WeightedLabeledPoint]): DecisionTreeModel = {
5350

5451
// Cache input RDD for speedup during multiple passes.
55-
weightedInput.cache()
52+
input.cache()
5653
logDebug("algo = " + strategy.algo)
5754

5855
// Find the splits and the corresponding bins (interval between the splits) using a sample
5956
// of the input data.
60-
val (splits, bins) = DecisionTree.findSplitsBins(weightedInput, strategy)
57+
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
6158
val numBins = bins(0).length
6259
logDebug("numBins = " + numBins)
6360

@@ -74,7 +71,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7471
// dummy value for top node (updated during first split calculation)
7572
val nodes = new Array[Node](maxNumNodes)
7673
// num features
77-
val numFeatures = weightedInput.take(1)(0).features.size
74+
val numFeatures = input.take(1)(0).features.size
7875

7976
// Calculate level for single group construction
8077

@@ -113,7 +110,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
113110
logDebug("#####################################")
114111

115112
// Find best split for all nodes at a level.
116-
val splitsStatsForLevel = DecisionTree.findBestSplits(weightedInput, parentImpurities,
113+
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities,
117114
strategy, level, filters, splits, bins, maxLevelForSingleGroup)
118115

119116
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
@@ -216,7 +213,9 @@ object DecisionTree extends Serializable with Logging {
216213
* @return a DecisionTreeModel that can be used for prediction
217214
*/
218215
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
219-
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
216+
// Converting from standard instance format to weighted input format for tree training
217+
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))
218+
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
220219
}
221220

222221
/**
@@ -238,7 +237,9 @@ object DecisionTree extends Serializable with Logging {
238237
impurity: Impurity,
239238
maxDepth: Int): DecisionTreeModel = {
240239
val strategy = new Strategy(algo,impurity,maxDepth)
241-
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
240+
// Converting from standard instance format to weighted input format for tree training
241+
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))
242+
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
242243
}
243244

244245

@@ -273,7 +274,9 @@ object DecisionTree extends Serializable with Logging {
273274
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
274275
val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
275276
categoricalFeaturesInfo)
276-
new DecisionTree(strategy).train(input: RDD[LabeledPoint])
277+
// Converting from standard instance format to weighted input format for tree training
278+
val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features))
279+
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
277280
}
278281

279282
private val InvalidBinIndex = -1

0 commit comments

Comments
 (0)