@@ -259,37 +259,6 @@ object DecisionTree extends Serializable with Logging {
259259 new DecisionTree (strategy).train(input)
260260 }
261261
262-
263- /**
264- * Method to train a decision tree model where the instances are represented as an RDD of
265- * (label, features) pairs. The method supports binary classification and regression. For the
266- * binary classification, the label for each instance should either be 0 or 1 to denote the two
267- * classes.
268- *
269- * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
270- * training data
271- * @param algo algorithm, classification or regression
272- * @param impurity impurity criterion used for information gain calculation
273- * @param maxDepth maximum depth of the tree
274- * @param numClassesForClassification number of classes for classification. Default value of 2.
275- * @param labelWeights A map storing weights for each label to handle unbalanced
276- * datasets. For example, an entry (n -> k) implies that a weight of k is
277- * applied to an instance with label n. It's important to note that labels
278- * are zero-indexed and take values 0, 1, 2, ... , numClasses - 1.
279- * @return a DecisionTreeModel that can be used for prediction
280- */
281- def train (
282- input : RDD [LabeledPoint ],
283- algo : Algo ,
284- impurity : Impurity ,
285- maxDepth : Int ,
286- numClassesForClassification : Int ,
287- labelWeights : Map [Int ,Int ]): DecisionTreeModel = {
288- val strategy = new Strategy (algo, impurity, maxDepth, numClassesForClassification,
289- labelWeights = labelWeights)
290- new DecisionTree (strategy).train(input)
291- }
292-
293262 /**
294263 * Method to train a decision tree model where the instances are represented as an RDD of
295264 * (label, features) pairs. The decision tree method supports binary classification and
@@ -303,10 +272,6 @@ object DecisionTree extends Serializable with Logging {
303272 * @param impurity criterion used for information gain calculation
304273 * @param maxDepth maximum depth of the tree
305274 * @param numClassesForClassification number of classes for classification. Default value of 2.
306- * @param labelWeights A map storing weights applied to each label for handling unbalanced
307- * datasets. For example, an entry (n -> k) implies that a weight of k is
308- * applied to an instance with label n. It's important to note that labels
309- * are zero-indexed and take values 0, 1, 2, ... , numClasses - 1.
310275 * @param maxBins maximum number of bins used for splitting features
311276 * @param quantileCalculationStrategy algorithm for calculating quantiles
312277 * @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -322,12 +287,11 @@ object DecisionTree extends Serializable with Logging {
322287 impurity : Impurity ,
323288 maxDepth : Int ,
324289 numClassesForClassification : Int ,
325- labelWeights : Map [Int ,Int ],
326290 maxBins : Int ,
327291 quantileCalculationStrategy : QuantileStrategy ,
328292 categoricalFeaturesInfo : Map [Int ,Int ]): DecisionTreeModel = {
329293 val strategy = new Strategy (algo, impurity, maxDepth, numClassesForClassification, maxBins,
330- quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights )
294+ quantileCalculationStrategy, categoricalFeaturesInfo)
331295 new DecisionTree (strategy).train(input)
332296 }
333297
@@ -442,8 +406,6 @@ object DecisionTree extends Serializable with Logging {
442406 logDebug(" numBins = " + numBins)
443407 val numClasses = strategy.numClassesForClassification
444408 logDebug(" numClasses = " + numClasses)
445- val labelWeights = strategy.labelWeights
446- logDebug(" labelWeights = " + labelWeights)
447409 val isMulticlassClassification = strategy.isMulticlassClassification
448410 logDebug(" isMulticlassClassification = " + isMulticlassClassification)
449411 val isMulticlassClassificationWithCategoricalFeatures
@@ -647,7 +609,7 @@ object DecisionTree extends Serializable with Logging {
647609 val aggIndex = aggShift + numClasses * featureIndex * numBins
648610 + arr(arrIndex).toInt * numClasses
649611 val labelInt = label.toInt
650- agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + labelWeights.getOrElse(labelInt, 1 )
612+ agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
651613 }
652614
653615 def updateBinForUnorderedFeature (nodeIndex : Int , featureIndex : Int , arr : Array [Double ],
@@ -667,10 +629,10 @@ object DecisionTree extends Serializable with Logging {
667629 val labelInt = label.toInt
668630 if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) {
669631 agg(aggIndex + binIndex)
670- = agg(aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1 )
632+ = agg(aggIndex + binIndex) + 1
671633 } else {
672634 agg(rightChildShift + aggIndex + binIndex)
673- = agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1 )
635+ = agg(rightChildShift + aggIndex + binIndex) + 1
674636 }
675637 binIndex += 1
676638 }
0 commit comments