Skip to content

Commit 19030a5

Browse files
committed
update boosting public APIs
1 parent d75579d commit 19030a5

File tree

16 files changed

+180
-244
lines changed

16 files changed

+180
-244
lines changed

examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
import org.apache.spark.api.java.function.Function2;
2828
import org.apache.spark.api.java.function.PairFunction;
2929
import org.apache.spark.mllib.regression.LabeledPoint;
30-
import org.apache.spark.mllib.tree.GradientBoosting;
30+
import org.apache.spark.mllib.tree.GradientBoostedTrees;
3131
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
32-
import org.apache.spark.mllib.tree.model.WeightedEnsembleModel;
32+
import org.apache.spark.mllib.tree.model.TreeEnsembleModel;
3333
import org.apache.spark.mllib.util.MLUtils;
3434

3535
/**
@@ -64,7 +64,7 @@ public static void main(String[] args) {
6464
// Note: All features are treated as continuous.
6565
BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
6666
boostingStrategy.setNumIterations(10);
67-
boostingStrategy.weakLearnerParams().setMaxDepth(5);
67+
boostingStrategy.treeStrategy().setMaxDepth(5);
6868

6969
if (algo.equals("Classification")) {
7070
// Compute the number of classes from the data.
@@ -76,7 +76,7 @@ public static void main(String[] args) {
7676
boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression
7777

7878
// Train a GradientBoosting model for classification.
79-
final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy);
79+
final TreeEnsembleModel model = GradientBoostedTrees.trainClassifier(data, boostingStrategy);
8080

8181
// Evaluate model on training instances and compute training error
8282
JavaPairRDD<Double, Double> predictionAndLabel =
@@ -95,7 +95,7 @@ public static void main(String[] args) {
9595
System.out.println("Learned classification tree model:\n" + model);
9696
} else if (algo.equals("Regression")) {
9797
// Train a GradientBoosting model for classification.
98-
final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy);
98+
final TreeEnsembleModel model = GradientBoostedTrees.trainRegressor(data, boostingStrategy);
9999

100100
// Evaluate model on training instances and compute training error
101101
JavaPairRDD<Double, Double> predictionAndLabel =

examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
2626
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
2727
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
2828
import org.apache.spark.mllib.tree.configuration.Algo._
29-
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
29+
import org.apache.spark.mllib.tree.model.{TreeEnsembleModel, DecisionTreeModel}
3030
import org.apache.spark.mllib.util.MLUtils
3131
import org.apache.spark.rdd.RDD
3232
import org.apache.spark.util.Utils
@@ -363,7 +363,7 @@ object DecisionTreeRunner {
363363
* Calculates the mean squared error for regression.
364364
*/
365365
private[mllib] def meanSquaredError(
366-
tree: WeightedEnsembleModel,
366+
tree: TreeEnsembleModel,
367367
data: RDD[LabeledPoint]): Double = {
368368
data.map { y =>
369369
val err = tree.predict(y.features) - y.label

examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scopt.OptionParser
2121

2222
import org.apache.spark.{SparkConf, SparkContext}
2323
import org.apache.spark.mllib.evaluation.MulticlassMetrics
24-
import org.apache.spark.mllib.tree.GradientBoosting
24+
import org.apache.spark.mllib.tree.GradientBoostedTrees
2525
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
2626
import org.apache.spark.util.Utils
2727

@@ -103,14 +103,14 @@ object GradientBoostedTrees {
103103
params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
104104

105105
val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
106-
boostingStrategy.numClassesForClassification = numClasses
106+
boostingStrategy.treeStrategy.numClassesForClassification = numClasses
107107
boostingStrategy.numIterations = params.numIterations
108-
boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth
108+
boostingStrategy.treeStrategy.maxDepth = params.maxDepth
109109

110110
val randomSeed = Utils.random.nextInt()
111111
if (params.algo == "Classification") {
112112
val startTime = System.nanoTime()
113-
val model = GradientBoosting.trainClassifier(training, boostingStrategy)
113+
val model = GradientBoostedTrees.train(training, boostingStrategy)
114114
val elapsedTime = (System.nanoTime() - startTime) / 1e9
115115
println(s"Training time: $elapsedTime seconds")
116116
if (model.totalNumNodes < 30) {
@@ -127,7 +127,7 @@ object GradientBoostedTrees {
127127
println(s"Test accuracy = $testAccuracy")
128128
} else if (params.algo == "Regression") {
129129
val startTime = System.nanoTime()
130-
val model = GradientBoosting.trainRegressor(training, boostingStrategy)
130+
val model = GradientBoostedTrees.trainRegressor(training, boostingStrategy)
131131
val elapsedTime = (System.nanoTime() - startTime) / 1e9
132132
println(s"Training time: $elapsedTime seconds")
133133
if (model.totalNumNodes < 30) {

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
6262
// Note: random seed will not be used since numTrees = 1.
6363
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
6464
val rfModel = rf.train(input)
65-
rfModel.weakHypotheses(0)
65+
rfModel.trees(0)
6666
}
6767

6868
}

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala renamed to mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 29 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,17 @@ import org.apache.spark.Logging
2121
import org.apache.spark.annotation.Experimental
2222
import org.apache.spark.api.java.JavaRDD
2323
import org.apache.spark.mllib.regression.LabeledPoint
24-
import org.apache.spark.mllib.tree.configuration.Algo._
2524
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
25+
import org.apache.spark.mllib.tree.configuration.Algo._
2626
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
2727
import org.apache.spark.mllib.tree.impl.TimeTracker
28-
import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
28+
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, TreeEnsembleModel}
2929
import org.apache.spark.rdd.RDD
3030
import org.apache.spark.storage.StorageLevel
3131

3232
/**
3333
* :: Experimental ::
34-
* A class that implements Stochastic Gradient Boosting
35-
* for regression and binary classification problems.
34+
* A class that implements Stochastic Gradient Boosting for regression and binary classification.
3635
*
3736
* The implementation is based upon:
3837
* J.H. Friedman. "Stochastic Gradient Boosting." 1999.
@@ -45,146 +44,84 @@ import org.apache.spark.storage.StorageLevel
4544
* but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
4645
* Running with those losses will likely behave reasonably, but lacks the same guarantees.
4746
*
48-
* @param boostingStrategy Parameters for the gradient boosting algorithm
47+
* @param boostingStrategy Parameters for the gradient boosting algorithm.
4948
*/
5049
@Experimental
51-
class GradientBoosting (
50+
class GradientBoostedTrees (
5251
private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
5352

54-
boostingStrategy.weakLearnerParams.algo = Regression
55-
boostingStrategy.weakLearnerParams.impurity = impurity.Variance
56-
57-
// Ensure values for weak learner are the same as what is provided to the boosting algorithm.
58-
boostingStrategy.weakLearnerParams.numClassesForClassification =
59-
boostingStrategy.numClassesForClassification
60-
61-
boostingStrategy.assertValid()
62-
6353
/**
6454
* Method to train a gradient boosting model
6555
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
6656
* @return WeightedEnsembleModel that can be used for prediction
6757
*/
68-
def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
69-
val algo = boostingStrategy.algo
58+
def train(input: RDD[LabeledPoint]): TreeEnsembleModel = {
59+
val algo = boostingStrategy.treeStrategy.algo
7060
algo match {
71-
case Regression => GradientBoosting.boost(input, boostingStrategy)
61+
case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
7262
case Classification =>
7363
// Map labels to -1, +1 so binary classification can be treated as regression.
7464
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
75-
GradientBoosting.boost(remappedInput, boostingStrategy)
65+
GradientBoostedTrees.boost(remappedInput, boostingStrategy)
7666
case _ =>
7767
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
7868
}
7969
}
80-
8170
}
8271

8372

84-
object GradientBoosting extends Logging {
73+
object GradientBoostedTrees extends Logging {
8574

8675
/**
8776
* Method to train a gradient boosting model.
8877
*
89-
* Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
90-
* is recommended to clearly specify regression.
91-
* Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
92-
* is recommended to clearly specify regression.
93-
*
9478
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
9579
* For classification, labels should take values {0, 1, ..., numClasses-1}.
9680
* For regression, labels are real numbers.
9781
* @param boostingStrategy Configuration options for the boosting algorithm.
98-
* @return WeightedEnsembleModel that can be used for prediction
82+
* @return a tree ensemble model that can be used for prediction
9983
*/
10084
def train(
10185
input: RDD[LabeledPoint],
102-
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
103-
new GradientBoosting(boostingStrategy).train(input)
104-
}
105-
106-
/**
107-
* Method to train a gradient boosting classification model.
108-
*
109-
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
110-
* For classification, labels should take values {0, 1, ..., numClasses-1}.
111-
* For regression, labels are real numbers.
112-
* @param boostingStrategy Configuration options for the boosting algorithm.
113-
* @return WeightedEnsembleModel that can be used for prediction
114-
*/
115-
def trainClassifier(
116-
input: RDD[LabeledPoint],
117-
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
118-
val algo = boostingStrategy.algo
119-
require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
120-
new GradientBoosting(boostingStrategy).train(input)
121-
}
122-
123-
/**
124-
* Method to train a gradient boosting regression model.
125-
*
126-
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
127-
* For classification, labels should take values {0, 1, ..., numClasses-1}.
128-
* For regression, labels are real numbers.
129-
* @param boostingStrategy Configuration options for the boosting algorithm.
130-
* @return WeightedEnsembleModel that can be used for prediction
131-
*/
132-
def trainRegressor(
133-
input: RDD[LabeledPoint],
134-
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
135-
val algo = boostingStrategy.algo
136-
require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
137-
new GradientBoosting(boostingStrategy).train(input)
86+
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
87+
new GradientBoostedTrees(boostingStrategy).train(input)
13888
}
13989

14090
/**
141-
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
91+
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
14292
*/
14393
def train(
144-
input: JavaRDD[LabeledPoint],
145-
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
146-
train(input.rdd, boostingStrategy)
147-
}
148-
149-
/**
150-
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
151-
*/
152-
def trainClassifier(
15394
input: JavaRDD[LabeledPoint],
154-
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
155-
trainClassifier(input.rdd, boostingStrategy)
156-
}
157-
158-
/**
159-
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
160-
*/
161-
def trainRegressor(
162-
input: JavaRDD[LabeledPoint],
163-
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
164-
trainRegressor(input.rdd, boostingStrategy)
95+
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
96+
train(input.rdd, boostingStrategy)
16597
}
16698

16799
/**
168100
* Internal method for performing regression using trees as base learners.
169101
* @param input training dataset
170102
* @param boostingStrategy boosting parameters
171-
* @return
103+
* @return a tree ensemble model that can be used for prediction
172104
*/
173105
private def boost(
174106
input: RDD[LabeledPoint],
175-
boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
107+
boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
176108

177109
val timer = new TimeTracker()
178110
timer.start("total")
179111
timer.start("init")
180112

113+
boostingStrategy.assertValid()
114+
181115
// Initialize gradient boosting parameters
182116
val numIterations = boostingStrategy.numIterations
183117
val baseLearners = new Array[DecisionTreeModel](numIterations)
184118
val baseLearnerWeights = new Array[Double](numIterations)
185119
val loss = boostingStrategy.loss
186120
val learningRate = boostingStrategy.learningRate
187-
val strategy = boostingStrategy.weakLearnerParams
121+
val ensembleStrategy = boostingStrategy.treeStrategy.copy
122+
ensembleStrategy.algo = Regression
123+
ensembleStrategy.impurity = impurity.Variance
124+
ensembleStrategy.assertValid()
188125

189126
// Cache input
190127
if (input.getStorageLevel == StorageLevel.NONE) {
@@ -200,11 +137,10 @@ object GradientBoosting extends Logging {
200137

201138
// Initialize tree
202139
timer.start("building tree 0")
203-
val firstTreeModel = new DecisionTree(strategy).train(data)
140+
val firstTreeModel = new DecisionTree(ensembleStrategy).train(data)
204141
baseLearners(0) = firstTreeModel
205142
baseLearnerWeights(0) = 1.0
206-
val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
207-
Sum)
143+
val startingModel = new TreeEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, Sum)
208144
logDebug("error of gbt = " + loss.computeError(startingModel, input))
209145
// Note: A model of type regression is used since we require raw prediction
210146
timer.stop("building tree 0")
@@ -219,7 +155,7 @@ object GradientBoosting extends Logging {
219155
logDebug("###################################################")
220156
logDebug("Gradient boosting tree iteration " + m)
221157
logDebug("###################################################")
222-
val model = new DecisionTree(strategy).train(data)
158+
val model = new DecisionTree(ensembleStrategy).train(data)
223159
timer.stop(s"building tree $m")
224160
// Create partial model
225161
baseLearners(m) = model
@@ -228,7 +164,7 @@ object GradientBoosting extends Logging {
228164
// However, the behavior should be reasonable, though not optimal.
229165
baseLearnerWeights(m) = learningRate
230166
// Note: A model of type regression is used since we require raw prediction
231-
val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
167+
val partialModel = new TreeEnsembleModel(baseLearners.slice(0, m + 1),
232168
baseLearnerWeights.slice(0, m + 1), Regression, Sum)
233169
logDebug("error of gbt = " + loss.computeError(partialModel, input))
234170
// Update data with pseudo-residuals
@@ -242,8 +178,6 @@ object GradientBoosting extends Logging {
242178
logInfo("Internal timing for DecisionTree:")
243179
logInfo(s"$timer")
244180

245-
new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
246-
181+
new TreeEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.treeStrategy.algo, Sum)
247182
}
248-
249183
}

0 commit comments

Comments
 (0)