Skip to content

Commit d8e4a11

Browse files
committed
sample weights
1 parent ed5a2df commit d8e4a11

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,39 @@ object DecisionTree extends Serializable with Logging {
268268
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
269269
}
270270

271-
// TODO: Add sample weight support
271+
272+
/**
273+
* Method to train a decision tree model where the instances are represented as an RDD of
274+
* (label, features) pairs. The method supports binary classification and regression. For the
275+
* binary classification, the label for each instance should either be 0 or 1 to denote the two
276+
* classes.
277+
*
278+
* @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
279+
* training data
280+
* @param algo algorithm, classification or regression
281+
* @param impurity impurity criterion used for information gain calculation
282+
* @param maxDepth maximum depth of the tree
283+
* @param numClassesForClassification number of classes for classification. Default value of 2.
284+
* @param labelWeights A map storing weights applied to each label for handling unbalanced
285+
* datasets. For example, an entry (n -> k) implies that a weight of k is
286+
* applied to an instance with label n. It's important to note that labels
287+
* are zero-indexed and take values 0, 1, 2, ..., numClasses - 1.
288+
* @return a DecisionTreeModel that can be used for prediction
289+
*/
290+
def train(
291+
input: RDD[LabeledPoint],
292+
algo: Algo,
293+
impurity: Impurity,
294+
maxDepth: Int,
295+
numClassesForClassification: Int,
296+
labelWeights: Map[Int,Int]): DecisionTreeModel = {
297+
val strategy
298+
= new Strategy(algo, impurity, maxDepth, numClassesForClassification,
299+
labelWeights = labelWeights)
300+
// Converting from standard instance format to weighted input format for tree training
301+
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
302+
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
303+
}
272304

273305
/**
274306
* Method to train a decision tree model where the instances are represented as an RDD of
@@ -283,6 +315,10 @@ object DecisionTree extends Serializable with Logging {
283315
* @param impurity criterion used for information gain calculation
284316
* @param maxDepth maximum depth of the tree
285317
* @param numClassesForClassification number of classes for classification. Default value of 2.
318+
* @param labelWeights A map storing weights applied to each label for handling unbalanced
319+
* datasets. For example, an entry (n -> k) implies that a weight of k is
320+
* applied to an instance with label n. It's important to note that labels
321+
* are zero-indexed and take values 0, 1, 2, ..., numClasses - 1.
286322
* @param maxBins maximum number of bins used for splitting features
287323
* @param quantileCalculationStrategy algorithm for calculating quantiles
288324
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -298,11 +334,12 @@ object DecisionTree extends Serializable with Logging {
298334
impurity: Impurity,
299335
maxDepth: Int,
300336
numClassesForClassification: Int,
337+
labelWeights: Map[Int,Int],
301338
maxBins: Int,
302339
quantileCalculationStrategy: QuantileStrategy,
303340
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
304341
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
305-
quantileCalculationStrategy, categoricalFeaturesInfo)
342+
quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights)
306343
// Converting from standard instance format to weighted input format for tree training
307344
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
308345
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
@@ -419,6 +456,9 @@ object DecisionTree extends Serializable with Logging {
419456
logDebug("numBins = " + numBins)
420457
val numClasses = strategy.numClassesForClassification
421458
logDebug("numClasses = " + numClasses)
459+
val labelWeights = strategy.labelWeights
460+
logDebug("labelWeights = " + labelWeights)
461+
422462

423463
// shift when more than one group is used at deep tree level
424464
val groupShift = numNodes * groupIndex
@@ -605,7 +645,8 @@ object DecisionTree extends Serializable with Logging {
605645
val aggIndex
606646
= aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
607647
label.toInt match {
608-
case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + 1
648+
case n: Int =>
649+
agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1)
609650
}
610651
featureIndex += 1
611652
}
@@ -1010,6 +1051,7 @@ object DecisionTree extends Serializable with Logging {
10101051
while (featureIndex < numFeatures) {
10111052
// Iterate over all splits.
10121053
var splitIndex = 0
1054+
// TODO: Modify this for categorical variables to go over only valid splits
10131055
while (splitIndex < numBins - 1) {
10141056
val gainStats = gains(featureIndex)(splitIndex)
10151057
if (gainStats.gain > bestGainStats.gain) {

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
3939
* zero-indexed.
4040
* @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
4141
* 128 MB.
42+
* @param labelWeights A map storing weights applied to each label for handling unbalanced
43+
* datasets. For example, an entry (n -> k) implies that a weight of k is
44+
* applied to an instance with label n. It's important to note that labels
45+
* are zero-indexed and take values 0, 1, 2, ..., numClasses - 1.
4246
*
4347
*/
4448
@Experimental
@@ -50,7 +54,8 @@ class Strategy (
5054
val maxBins: Int = 100,
5155
val quantileCalculationStrategy: QuantileStrategy = Sort,
5256
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
53-
val maxMemoryInMB: Int = 128) extends Serializable {
57+
val maxMemoryInMB: Int = 128,
58+
val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable {
5459

5560
require(numClassesForClassification >= 2)
5661
val isMultiClassification = numClassesForClassification > 2

0 commit comments

Comments
 (0)