Skip to content

Commit 6b912dc

Browse files
committed
added numclasses to tree runner, predict logic for multiclass, add multiclass option to train
1 parent 75f2bfc commit 6b912dc

File tree

4 files changed

+69
-24
lines changed

4 files changed

+69
-24
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ object DecisionTreeRunner {
4949
case class Params(
5050
input: String = null,
5151
algo: Algo = Classification,
52+
numClasses: Int = 2,
5253
maxDepth: Int = 5,
5354
impurity: ImpurityType = Gini,
5455
maxBins: Int = 100)
@@ -68,6 +69,9 @@ object DecisionTreeRunner {
6869
opt[Int]("maxDepth")
6970
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
7071
.action((x, c) => c.copy(maxDepth = x))
72+
opt[Int]("numClasses")
73+
.text(s"number of classes for classification, default: ${defaultParams.numClasses}")
74+
.action((x, c) => c.copy(numClasses = x))
7175
opt[Int]("maxBins")
7276
.text(s"max number of bins, default: ${defaultParams.maxBins}")
7377
.action((x, c) => c.copy(maxBins = x))
@@ -139,12 +143,8 @@ object DecisionTreeRunner {
139143
*/
140144
private def accuracyScore(
141145
model: DecisionTreeModel,
142-
data: RDD[LabeledPoint],
143-
threshold: Double = 0.5): Double = {
144-
def predictedValue(features: Vector): Double = {
145-
if (model.predict(features) < threshold) 0.0 else 1.0
146-
}
147-
val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
146+
data: RDD[LabeledPoint]): Double = {
147+
val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
148148
val count = data.count()
149149
correctCount.toDouble / count
150150
}

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -231,19 +231,43 @@ object DecisionTree extends Serializable with Logging {
231231
* @param maxDepth maxDepth maximum depth of the tree
232232
* @return a DecisionTreeModel that can be used for prediction
233233
*/
234+
def train(
235+
input: RDD[LabeledPoint],
236+
algo: Algo,
237+
impurity: Impurity,
238+
maxDepth: Int): DecisionTreeModel = {
239+
val strategy = new Strategy(algo,impurity,maxDepth)
240+
// Converting from standard instance format to weighted input format for tree training
241+
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
242+
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
243+
}
244+
245+
/**
246+
* Method to train a decision tree model where the instances are represented as an RDD of
247+
* (label, features) pairs. The method supports binary classification and regression. For the
248+
* binary classification, the label for each instance should either be 0 or 1 to denote the two
249+
* classes.
250+
*
251+
* @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
252+
* training data
253+
* @param algo algorithm, classification or regression
254+
* @param impurity impurity criterion used for information gain calculation
255+
* @param maxDepth maxDepth maximum depth of the tree
256+
* @param numClasses number of classes for classification
257+
* @return a DecisionTreeModel that can be used for prediction
258+
*/
234259
def train(
235260
input: RDD[LabeledPoint],
236261
algo: Algo,
237262
impurity: Impurity,
238-
maxDepth: Int): DecisionTreeModel = {
239-
val strategy = new Strategy(algo,impurity,maxDepth)
263+
maxDepth: Int,
264+
numClasses: Int): DecisionTreeModel = {
265+
val strategy = new Strategy(algo,impurity,maxDepth,numClasses)
240266
// Converting from standard instance format to weighted input format for tree training
241267
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
242268
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
243269
}
244270

245-
// TODO: Add multiclass classification support
246-
247271
// TODO: Add sample weight support
248272

249273
/**
@@ -258,6 +282,7 @@ object DecisionTree extends Serializable with Logging {
258282
* @param algo classification or regression
259283
* @param impurity criterion used for information gain calculation
260284
* @param maxDepth maximum depth of the tree
285+
* @param numClasses number of classes for classification
261286
* @param maxBins maximum number of bins used for splitting features
262287
* @param quantileCalculationStrategy algorithm for calculating quantiles
263288
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -272,11 +297,12 @@ object DecisionTree extends Serializable with Logging {
272297
algo: Algo,
273298
impurity: Impurity,
274299
maxDepth: Int,
300+
numClasses: Int,
275301
maxBins: Int,
276302
quantileCalculationStrategy: QuantileStrategy,
277303
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
278-
val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
279-
categoricalFeaturesInfo)
304+
val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
305+
quantileCalculationStrategy, categoricalFeaturesInfo)
280306
// Converting from standard instance format to weighted input format for tree training
281307
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
282308
new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
@@ -737,10 +763,26 @@ object DecisionTree extends Serializable with Logging {
737763
}
738764
}
739765

740-
//TODO: Make multiclass modification here
741-
val predict = (leftCounts(1) + rightCounts(1)) / (leftTotalCount + rightTotalCount)
766+
val totalCount = leftTotalCount + rightTotalCount
742767

743-
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
768+
// Sum of count for each label
769+
val leftRightCounts: Array[Double]
770+
= leftCounts.zip(rightCounts)
771+
.map{case (leftCount, rightCount) => leftCount + rightCount}
772+
773+
def indexOfLargest(array: Seq[Double]): Int = {
774+
val result = array.foldLeft(-1,Double.MinValue,0) {
775+
case ((maxIndex, maxValue, currentIndex), currentValue) =>
776+
if(currentValue > maxValue) (currentIndex,currentValue,currentIndex+1)
777+
else (maxIndex,maxValue,currentIndex+1)
778+
}
779+
if (result._1 < 0) result._1 else 0
780+
}
781+
782+
val predict = indexOfLargest(leftRightCounts)
783+
val prob = leftRightCounts(predict) / totalCount
784+
785+
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
744786
case Regression =>
745787
val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
746788
val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
@@ -793,8 +835,9 @@ object DecisionTree extends Serializable with Logging {
793835
/**
794836
* Extracts left and right split aggregates.
795837
* @param binData Array[Double] of size 2*numFeatures*numSplits
796-
* @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double],
797-
* Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
838+
* @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\],
839+
* Array[Array[Array[Double\]\]\]) where each array is of size(numFeature,
840+
* (numBins - 1), numClasses)
798841
*/
799842
def extractLeftRightNodeAggregates(
800843
binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
2828
* @param algo classification or regression
2929
* @param impurity criterion used for information gain calculation
3030
* @param maxDepth maximum depth of the tree
31+
* @param numClassesForClassification number of classes for classification. Default value is 2
32+
* leads to binary classification
3133
* @param maxBins maximum number of bins used for splitting features
3234
* @param quantileCalculationStrategy algorithm for calculating quantiles
3335
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
@@ -37,20 +39,18 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
3739
* zero-indexed.
3840
* @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
3941
* 128 MB.
40-
* @param numClassesForClassification number of classes for classification. Default value is 2
41-
* leads to binary classification
4242
*
4343
*/
4444
@Experimental
4545
class Strategy (
4646
val algo: Algo,
4747
val impurity: Impurity,
4848
val maxDepth: Int,
49+
val numClassesForClassification: Int = 2,
4950
val maxBins: Int = 100,
5051
val quantileCalculationStrategy: QuantileStrategy = Sort,
5152
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
52-
val maxMemoryInMB: Int = 128,
53-
val numClassesForClassification: Int = 2) extends Serializable {
53+
val maxMemoryInMB: Int = 128) extends Serializable {
5454

5555
require(numClassesForClassification >= 2)
5656
val isMultiClassification = numClassesForClassification > 2

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,19 @@ import org.apache.spark.annotation.DeveloperApi
2727
* @param leftImpurity left node impurity
2828
* @param rightImpurity right node impurity
2929
* @param predict predicted value
30+
* @param prob probability of the label (classification only)
3031
*/
3132
@DeveloperApi
3233
class InformationGainStats(
3334
val gain: Double,
3435
val impurity: Double,
3536
val leftImpurity: Double,
3637
val rightImpurity: Double,
37-
val predict: Double) extends Serializable {
38+
val predict: Double,
39+
val prob: Double = 0.0) extends Serializable {
3840

3941
override def toString = {
40-
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
41-
.format(gain, impurity, leftImpurity, rightImpurity, predict)
42+
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
43+
.format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
4244
}
4345
}

0 commit comments

Comments
 (0)