Skip to content

Commit afced16

Browse files
committed
removed label weights support
1 parent 2d85a48 commit afced16

File tree

2 files changed

+5
-48
lines changed

2 files changed

+5
-48
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -259,37 +259,6 @@ object DecisionTree extends Serializable with Logging {
259259
new DecisionTree(strategy).train(input)
260260
}
261261

262-
263-
/**
264-
* Method to train a decision tree model where the instances are represented as an RDD of
265-
* (label, features) pairs. The method supports binary classification and regression. For the
266-
* binary classification, the label for each instance should either be 0 or 1 to denote the two
267-
* classes.
268-
*
269-
* @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
270-
* training data
271-
* @param algo algorithm, classification or regression
272-
* @param impurity impurity criterion used for information gain calculation
273-
* @param maxDepth maxDepth maximum depth of the tree
274-
* @param numClassesForClassification number of classes for classification. Default value of 2.
275-
* @param labelWeights A map storing weights for each label to handle unbalanced
276-
* datasets. For example, an entry (n -> k) implies the a weight of k is
277-
* applied to an instance with label n. It's important to note that labels
278-
* are zero-index and take values 0, 1, 2, ... , numClasses - 1.
279-
* @return a DecisionTreeModel that can be used for prediction
280-
*/
281-
def train(
282-
input: RDD[LabeledPoint],
283-
algo: Algo,
284-
impurity: Impurity,
285-
maxDepth: Int,
286-
numClassesForClassification: Int,
287-
labelWeights: Map[Int,Int]): DecisionTreeModel = {
288-
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification,
289-
labelWeights = labelWeights)
290-
new DecisionTree(strategy).train(input)
291-
}
292-
293262
/**
294263
* Method to train a decision tree model where the instances are represented as an RDD of
295264
* (label, features) pairs. The decision tree method supports binary classification and
@@ -303,10 +272,6 @@ object DecisionTree extends Serializable with Logging {
303272
* @param impurity criterion used for information gain calculation
304273
* @param maxDepth maximum depth of the tree
305274
* @param numClassesForClassification number of classes for classification. Default value of 2.
306-
* @param labelWeights A map storing weights applied to each label for handling unbalanced
307-
* datasets. For example, an entry (n -> k) implies the a weight of k is
308-
* applied to an instance with label n. It's important to note that labels
309-
* are zero-index and take values 0, 1, 2, ... , numClasses - 1.
310275
* @param maxBins maximum number of bins used for splitting features
311276
* @param quantileCalculationStrategy algorithm for calculating quantiles
312277
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -322,12 +287,11 @@ object DecisionTree extends Serializable with Logging {
322287
impurity: Impurity,
323288
maxDepth: Int,
324289
numClassesForClassification: Int,
325-
labelWeights: Map[Int,Int],
326290
maxBins: Int,
327291
quantileCalculationStrategy: QuantileStrategy,
328292
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
329293
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
330-
quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights)
294+
quantileCalculationStrategy, categoricalFeaturesInfo)
331295
new DecisionTree(strategy).train(input)
332296
}
333297

@@ -442,8 +406,6 @@ object DecisionTree extends Serializable with Logging {
442406
logDebug("numBins = " + numBins)
443407
val numClasses = strategy.numClassesForClassification
444408
logDebug("numClasses = " + numClasses)
445-
val labelWeights = strategy.labelWeights
446-
logDebug("labelWeights = " + labelWeights)
447409
val isMulticlassClassification = strategy.isMulticlassClassification
448410
logDebug("isMulticlassClassification = " + isMulticlassClassification)
449411
val isMulticlassClassificationWithCategoricalFeatures
@@ -647,7 +609,7 @@ object DecisionTree extends Serializable with Logging {
647609
val aggIndex = aggShift + numClasses * featureIndex * numBins
648610
+ arr(arrIndex).toInt * numClasses
649611
val labelInt = label.toInt
650-
agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + labelWeights.getOrElse(labelInt, 1)
612+
agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
651613
}
652614

653615
def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
@@ -667,10 +629,10 @@ object DecisionTree extends Serializable with Logging {
667629
val labelInt = label.toInt
668630
if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) {
669631
agg(aggIndex + binIndex)
670-
= agg(aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1)
632+
= agg(aggIndex + binIndex) + 1
671633
} else {
672634
agg(rightChildShift + aggIndex + binIndex)
673-
= agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1)
635+
= agg(rightChildShift + aggIndex + binIndex) + 1
674636
}
675637
binIndex += 1
676638
}

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
3939
* zero-indexed.
4040
* @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
4141
* 128 MB.
42-
* @param labelWeights A map storing weights applied to each label for handling unbalanced
43-
* datasets. For example, an entry (n -> k) implies the a weight of k is
44-
* applied to an instance with label n. It's important to note that labels
45-
* are zero-index and take values 0, 1, 2, ... , numClasses.
4642
*
4743
*/
4844
@Experimental
@@ -54,8 +50,7 @@ class Strategy (
5450
val maxBins: Int = 100,
5551
val quantileCalculationStrategy: QuantileStrategy = Sort,
5652
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
57-
val maxMemoryInMB: Int = 128,
58-
val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable {
53+
val maxMemoryInMB: Int = 128) extends Serializable {
5954

6055
require(numClassesForClassification >= 2)
6156
val isMulticlassClassification = numClassesForClassification > 2

0 commit comments

Comments
 (0)