@@ -268,7 +268,39 @@ object DecisionTree extends Serializable with Logging {
268268 new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
269269 }
270270
271- // TODO: Add sample weight support
271+
272+ /**
273+ * Method to train a decision tree model where the instances are represented as an RDD of
274+ * (label, features) pairs. The method supports binary classification and regression. For the
275+ * binary classification, the label for each instance should either be 0 or 1 to denote the two
276+ * classes.
277+ *
278+ * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as
279+ * training data
280+ * @param algo algorithm, classification or regression
281+ * @param impurity impurity criterion used for information gain calculation
282+ * @param maxDepth maxDepth maximum depth of the tree
283+ * @param numClassesForClassification number of classes for classification. Default value of 2.
284+ * @param labelWeights A map storing weights applied to each label for handling unbalanced
285+ * datasets. For example, an entry (n -> k) implies the a weight of k is
286+ * applied to an instance with label n. It's important to note that labels
287+ * are zero-index and take values 0, 1, 2, ... , numClasses.
288+ * @return a DecisionTreeModel that can be used for prediction
289+ */
290+ def train (
291+ input : RDD [LabeledPoint ],
292+ algo : Algo ,
293+ impurity : Impurity ,
294+ maxDepth : Int ,
295+ numClassesForClassification : Int ,
296+ labelWeights : Map [Int ,Int ]): DecisionTreeModel = {
297+ val strategy
298+ = new Strategy (algo, impurity, maxDepth, numClassesForClassification,
299+ labelWeights = labelWeights)
300+ // Converting from standard instance format to weighted input format for tree training
301+ val weightedInput = input.map(x => WeightedLabeledPoint (x.label, x.features))
302+ new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
303+ }
272304
273305 /**
274306 * Method to train a decision tree model where the instances are represented as an RDD of
@@ -283,6 +315,10 @@ object DecisionTree extends Serializable with Logging {
283315 * @param impurity criterion used for information gain calculation
284316 * @param maxDepth maximum depth of the tree
285317 * @param numClassesForClassification number of classes for classification. Default value of 2.
318+ * @param labelWeights A map storing weights applied to each label for handling unbalanced
319+ * datasets. For example, an entry (n -> k) implies the a weight of k is
320+ * applied to an instance with label n. It's important to note that labels
321+ * are zero-index and take values 0, 1, 2, ... , numClasses.
286322 * @param maxBins maximum number of bins used for splitting features
287323 * @param quantileCalculationStrategy algorithm for calculating quantiles
288324 * @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -298,11 +334,12 @@ object DecisionTree extends Serializable with Logging {
298334 impurity : Impurity ,
299335 maxDepth : Int ,
300336 numClassesForClassification : Int ,
337+ labelWeights : Map [Int ,Int ],
301338 maxBins : Int ,
302339 quantileCalculationStrategy : QuantileStrategy ,
303340 categoricalFeaturesInfo : Map [Int ,Int ]): DecisionTreeModel = {
304341 val strategy = new Strategy (algo, impurity, maxDepth, numClassesForClassification, maxBins,
305- quantileCalculationStrategy, categoricalFeaturesInfo)
342+ quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights )
306343 // Converting from standard instance format to weighted input format for tree training
307344 val weightedInput = input.map(x => WeightedLabeledPoint (x.label, x.features))
308345 new DecisionTree (strategy).train(weightedInput : RDD [WeightedLabeledPoint ])
@@ -419,6 +456,9 @@ object DecisionTree extends Serializable with Logging {
419456 logDebug(" numBins = " + numBins)
420457 val numClasses = strategy.numClassesForClassification
421458 logDebug(" numClasses = " + numClasses)
459+ val labelWeights = strategy.labelWeights
460+ logDebug(" labelWeights = " + labelWeights)
461+
422462
423463 // shift when more than one group is used at deep tree level
424464 val groupShift = numNodes * groupIndex
@@ -605,7 +645,8 @@ object DecisionTree extends Serializable with Logging {
605645 val aggIndex
606646 = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
607647 label.toInt match {
608- case n : Int => agg(aggIndex + n) = agg(aggIndex + n) + 1
648+ case n : Int =>
649+ agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1 )
609650 }
610651 featureIndex += 1
611652 }
@@ -1010,6 +1051,7 @@ object DecisionTree extends Serializable with Logging {
10101051 while (featureIndex < numFeatures) {
10111052 // Iterate over all splits.
10121053 var splitIndex = 0
1054+ // TODO: Modify this for categorical variables to go over only valid splits
10131055 while (splitIndex < numBins - 1 ) {
10141056 val gainStats = gains(featureIndex)(splitIndex)
10151057 if (gainStats.gain > bestGainStats.gain) {
0 commit comments