Skip to content

Commit 4aae3b7

Browse files
committed
add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy
1 parent ea4c467 commit 4aae3b7

File tree

7 files changed

+121
-122
lines changed

7 files changed

+121
-122
lines changed

examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import org.apache.spark.mllib.regression.LabeledPoint;
3030
import org.apache.spark.mllib.tree.GradientBoostedTrees;
3131
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
32-
import org.apache.spark.mllib.tree.model.TreeEnsembleModel;
32+
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
3333
import org.apache.spark.mllib.util.MLUtils;
3434

3535
/**
@@ -76,7 +76,7 @@ public static void main(String[] args) {
7676
boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
7777

7878
// Train a GradientBoosting model for classification.
79-
final TreeEnsembleModel model = GradientBoostedTrees.train(data, boostingStrategy);
79+
final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
8080

8181
// Evaluate model on training instances and compute training error
8282
JavaPairRDD<Double, Double> predictionAndLabel =
@@ -95,7 +95,7 @@ public static void main(String[] args) {
9595
System.out.println("Learned classification tree model:\n" + model);
9696
} else if (algo.equals("Regression")) {
9797
// Train a GradientBoosting model for classification.
98-
final TreeEnsembleModel model = GradientBoostedTrees.train(data, boostingStrategy);
98+
final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
9999

100100
// Evaluate model on training instances and compute training error
101101
JavaPairRDD<Double, Double> predictionAndLabel =

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ import scopt.OptionParser
2222
import org.apache.spark.{SparkConf, SparkContext}
2323
import org.apache.spark.SparkContext._
2424
import org.apache.spark.mllib.evaluation.MulticlassMetrics
25+
import org.apache.spark.mllib.linalg.Vector
2526
import org.apache.spark.mllib.regression.LabeledPoint
26-
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
27+
import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
2728
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
2829
import org.apache.spark.mllib.tree.configuration.Algo._
29-
import org.apache.spark.mllib.tree.model.{TreeEnsembleModel, DecisionTreeModel}
3030
import org.apache.spark.mllib.util.MLUtils
3131
import org.apache.spark.rdd.RDD
3232
import org.apache.spark.util.Utils
@@ -349,24 +349,14 @@ object DecisionTreeRunner {
349349
sc.stop()
350350
}
351351

352-
/**
353-
* Calculates the mean squared error for regression.
354-
*/
355-
private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
356-
data.map { y =>
357-
val err = tree.predict(y.features) - y.label
358-
err * err
359-
}.mean()
360-
}
361-
362352
/**
363353
* Calculates the mean squared error for regression.
364354
*/
365355
private[mllib] def meanSquaredError(
366-
tree: TreeEnsembleModel,
356+
model: { def predict(features: Vector): Double },
367357
data: RDD[LabeledPoint]): Double = {
368358
data.map { y =>
369-
val err = tree.predict(y.features) - y.label
359+
val err = model.predict(y.features) - y.label
370360
err * err
371361
}.mean()
372362
}

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ import org.apache.spark.api.java.JavaRDD
2323
import org.apache.spark.mllib.regression.LabeledPoint
2424
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
2525
import org.apache.spark.mllib.tree.configuration.Algo._
26-
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
2726
import org.apache.spark.mllib.tree.impl.TimeTracker
28-
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, TreeEnsembleModel}
27+
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
2928
import org.apache.spark.rdd.RDD
3029
import org.apache.spark.storage.StorageLevel
3130

@@ -53,9 +52,9 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
5352
/**
5453
* Method to train a gradient boosting model
5554
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
56-
* @return WeightedEnsembleModel that can be used for prediction
55+
* @return a gradient boosted trees model that can be used for prediction
5756
*/
58-
def run(input: RDD[LabeledPoint]): TreeEnsembleModel = {
57+
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
5958
val algo = boostingStrategy.treeStrategy.algo
6059
algo match {
6160
case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
@@ -71,7 +70,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
7170
/**
7271
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
7372
*/
74-
def run(input: JavaRDD[LabeledPoint]): TreeEnsembleModel = {
73+
def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
7574
run(input.rdd)
7675
}
7776
}
@@ -86,11 +85,11 @@ object GradientBoostedTrees extends Logging {
8685
* For classification, labels should take values {0, 1, ..., numClasses-1}.
8786
* For regression, labels are real numbers.
8887
* @param boostingStrategy Configuration options for the boosting algorithm.
89-
* @return a tree ensemble model that can be used for prediction
88+
* @return a gradient boosted trees model that can be used for prediction
9089
*/
9190
def train(
9291
input: RDD[LabeledPoint],
93-
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
92+
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
9493
new GradientBoostedTrees(boostingStrategy).run(input)
9594
}
9695

@@ -99,19 +98,19 @@ object GradientBoostedTrees extends Logging {
9998
*/
10099
def train(
101100
input: JavaRDD[LabeledPoint],
102-
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
101+
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
103102
train(input.rdd, boostingStrategy)
104103
}
105104

106105
/**
107106
* Internal method for performing regression using trees as base learners.
108107
* @param input training dataset
109108
* @param boostingStrategy boosting parameters
110-
* @return a tree ensemble model that can be used for prediction
109+
* @return a gradient boosted trees model that can be used for prediction
111110
*/
112111
private def boost(
113112
input: RDD[LabeledPoint],
114-
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
113+
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
115114

116115
val timer = new TimeTracker()
117116
timer.start("total")
@@ -148,7 +147,7 @@ object GradientBoostedTrees extends Logging {
148147
val firstTreeModel = new DecisionTree(ensembleStrategy).run(data)
149148
baseLearners(0) = firstTreeModel
150149
baseLearnerWeights(0) = 1.0
151-
val startingModel = new TreeEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, Sum)
150+
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
152151
logDebug("error of gbt = " + loss.computeError(startingModel, input))
153152
// Note: A model of type regression is used since we require raw prediction
154153
timer.stop("building tree 0")
@@ -172,8 +171,8 @@ object GradientBoostedTrees extends Logging {
172171
// However, the behavior should be reasonable, though not optimal.
173172
baseLearnerWeights(m) = learningRate
174173
// Note: A model of type regression is used since we require raw prediction
175-
val partialModel = new TreeEnsembleModel(baseLearners.slice(0, m + 1),
176-
baseLearnerWeights.slice(0, m + 1), Regression, Sum)
174+
val partialModel = new GradientBoostedTreesModel(
175+
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
177176
logDebug("error of gbt = " + loss.computeError(partialModel, input))
178177
// Update data with pseudo-residuals
179178
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
@@ -186,6 +185,7 @@ object GradientBoostedTrees extends Logging {
186185
logInfo("Internal timing for DecisionTree:")
187186
logInfo(s"$timer")
188187

189-
new TreeEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.treeStrategy.algo, Sum)
188+
new GradientBoostedTreesModel(
189+
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
190190
}
191191
}

mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@
1717

1818
package org.apache.spark.mllib.tree
1919

20-
import scala.collection.JavaConverters._
2120
import scala.collection.mutable
21+
import scala.collection.JavaConverters._
2222

2323
import org.apache.spark.Logging
2424
import org.apache.spark.annotation.Experimental
2525
import org.apache.spark.api.java.JavaRDD
2626
import org.apache.spark.mllib.regression.LabeledPoint
27+
import org.apache.spark.mllib.tree.configuration.Strategy
2728
import org.apache.spark.mllib.tree.configuration.Algo._
2829
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
29-
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
30-
import org.apache.spark.mllib.tree.configuration.Strategy
31-
import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache }
30+
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache,
31+
TimeTracker, TreePoint}
3232
import org.apache.spark.mllib.tree.impurity.Impurities
3333
import org.apache.spark.mllib.tree.model._
3434
import org.apache.spark.rdd.RDD
@@ -79,9 +79,9 @@ private class RandomForest (
7979
/**
8080
* Method to train a random forest model over an RDD
8181
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
82-
* @return WeightedEnsembleModel that can be used for prediction
82+
* @return a random forest model that can be used for prediction
8383
*/
84-
def run(input: RDD[LabeledPoint]): TreeEnsembleModel = {
84+
def run(input: RDD[LabeledPoint]): RandomForestModel = {
8585

8686
val timer = new TimeTracker()
8787

@@ -212,8 +212,7 @@ private class RandomForest (
212212
}
213213

214214
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
215-
val treeWeights = Array.fill[Double](numTrees)(1.0)
216-
new TreeEnsembleModel(trees, treeWeights, strategy.algo, Average)
215+
new RandomForestModel(strategy.algo, trees)
217216
}
218217

219218
}
@@ -234,14 +233,14 @@ object RandomForest extends Serializable with Logging {
234233
* if numTrees > 1 (forest) set to "sqrt" for classification and
235234
* to "onethird" for regression.
236235
* @param seed Random seed for bootstrapping and choosing feature subsets.
237-
* @return WeightedEnsembleModel that can be used for prediction
236+
* @return a random forest model that can be used for prediction
238237
*/
239238
def trainClassifier(
240239
input: RDD[LabeledPoint],
241240
strategy: Strategy,
242241
numTrees: Int,
243242
featureSubsetStrategy: String,
244-
seed: Int): TreeEnsembleModel = {
243+
seed: Int): RandomForestModel = {
245244
require(strategy.algo == Classification,
246245
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
247246
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -272,7 +271,7 @@ object RandomForest extends Serializable with Logging {
272271
* @param maxBins maximum number of bins used for splitting features
273272
* (suggested value: 100)
274273
* @param seed Random seed for bootstrapping and choosing feature subsets.
275-
* @return WeightedEnsembleModel that can be used for prediction
274+
* @return a random forest model that can be used for prediction
276275
*/
277276
def trainClassifier(
278277
input: RDD[LabeledPoint],
@@ -283,7 +282,7 @@ object RandomForest extends Serializable with Logging {
283282
impurity: String,
284283
maxDepth: Int,
285284
maxBins: Int,
286-
seed: Int = Utils.random.nextInt()): TreeEnsembleModel = {
285+
seed: Int = Utils.random.nextInt()): RandomForestModel = {
287286
val impurityType = Impurities.fromString(impurity)
288287
val strategy = new Strategy(Classification, impurityType, maxDepth,
289288
numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
@@ -302,7 +301,7 @@ object RandomForest extends Serializable with Logging {
302301
impurity: String,
303302
maxDepth: Int,
304303
maxBins: Int,
305-
seed: Int): TreeEnsembleModel = {
304+
seed: Int): RandomForestModel = {
306305
trainClassifier(input.rdd, numClassesForClassification,
307306
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
308307
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -322,14 +321,14 @@ object RandomForest extends Serializable with Logging {
322321
* if numTrees > 1 (forest) set to "sqrt" for classification and
323322
* to "onethird" for regression.
324323
* @param seed Random seed for bootstrapping and choosing feature subsets.
325-
* @return WeightedEnsembleModel that can be used for prediction
324+
* @return a random forest model that can be used for prediction
326325
*/
327326
def trainRegressor(
328327
input: RDD[LabeledPoint],
329328
strategy: Strategy,
330329
numTrees: Int,
331330
featureSubsetStrategy: String,
332-
seed: Int): TreeEnsembleModel = {
331+
seed: Int): RandomForestModel = {
333332
require(strategy.algo == Regression,
334333
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
335334
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -359,7 +358,7 @@ object RandomForest extends Serializable with Logging {
359358
* @param maxBins maximum number of bins used for splitting features
360359
* (suggested value: 100)
361360
* @param seed Random seed for bootstrapping and choosing feature subsets.
362-
* @return WeightedEnsembleModel that can be used for prediction
361+
* @return a random forest model that can be used for prediction
363362
*/
364363
def trainRegressor(
365364
input: RDD[LabeledPoint],
@@ -369,7 +368,7 @@ object RandomForest extends Serializable with Logging {
369368
impurity: String,
370369
maxDepth: Int,
371370
maxBins: Int,
372-
seed: Int = Utils.random.nextInt()): TreeEnsembleModel = {
371+
seed: Int = Utils.random.nextInt()): RandomForestModel = {
373372
val impurityType = Impurities.fromString(impurity)
374373
val strategy = new Strategy(Regression, impurityType, maxDepth,
375374
0, maxBins, Sort, categoricalFeaturesInfo)
@@ -479,5 +478,4 @@ object RandomForest extends Serializable with Logging {
479478
3 * totalBins
480479
}
481480
}
482-
483481
}

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,10 @@
1717

1818
package org.apache.spark.mllib.tree.configuration
1919

20-
import org.apache.spark.annotation.DeveloperApi
21-
2220
/**
23-
* :: DeveloperApi ::
2421
* Enum to select ensemble combining strategy for base learners
2522
*/
26-
@DeveloperApi
27-
object EnsembleCombiningStrategy extends Enumeration {
23+
private[tree] object EnsembleCombiningStrategy extends Enumeration {
2824
type EnsembleCombiningStrategy = Value
29-
val Sum, Average = Value
25+
val Average, Sum, Vote = Value
3026
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717

1818
package org.apache.spark.mllib.tree.model
1919

20-
import org.apache.spark.api.java.JavaRDD
2120
import org.apache.spark.annotation.Experimental
21+
import org.apache.spark.api.java.JavaRDD
22+
import org.apache.spark.mllib.linalg.Vector
2223
import org.apache.spark.mllib.tree.configuration.Algo._
2324
import org.apache.spark.rdd.RDD
24-
import org.apache.spark.mllib.linalg.Vector
2525

2626
/**
2727
* :: Experimental ::

0 commit comments

Comments
 (0)