@@ -21,18 +21,17 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
 import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, TreeEnsembleModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel

 /**
  * :: Experimental ::
- * A class that implements Stochastic Gradient Boosting
- * for regression and binary classification problems.
+ * A class that implements Stochastic Gradient Boosting for regression and binary classification.
  *
  * The implementation is based upon:
  *     J.H. Friedman.  "Stochastic Gradient Boosting."  1999.
@@ -45,146 +44,84 @@ import org.apache.spark.storage.StorageLevel
  * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
  * Running with those losses will likely behave reasonably, but lacks the same guarantees.
  *
- * @param boostingStrategy Parameters for the gradient boosting algorithm
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
  */
 @Experimental
-class GradientBoosting(
+class GradientBoostedTrees(
     private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {

-  boostingStrategy.weakLearnerParams.algo = Regression
-  boostingStrategy.weakLearnerParams.impurity = impurity.Variance
-
-  // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
-  boostingStrategy.weakLearnerParams.numClassesForClassification =
-    boostingStrategy.numClassesForClassification
-
-  boostingStrategy.assertValid()
-
   /**
    * Method to train a gradient boosting model
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return WeightedEnsembleModel that can be used for prediction
    */
-  def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
-    val algo = boostingStrategy.algo
+  def train(input: RDD[LabeledPoint]): TreeEnsembleModel = {
+    val algo = boostingStrategy.treeStrategy.algo
     algo match {
-      case Regression => GradientBoosting.boost(input, boostingStrategy)
+      case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
       case Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
         val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-        GradientBoosting.boost(remappedInput, boostingStrategy)
+        GradientBoostedTrees.boost(remappedInput, boostingStrategy)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
     }
   }
-
 }


-object GradientBoosting extends Logging {
+object GradientBoostedTrees extends Logging {

   /**
    * Method to train a gradient boosting model.
    *
-   * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
-   *       is recommended to clearly specify regression.
-   * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
-   *       is recommended to clearly specify regression.
-   *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
    *              For regression, labels are real numbers.
    * @param boostingStrategy Configuration options for the boosting algorithm.
-   * @return WeightedEnsembleModel that can be used for prediction
+   * @return a tree ensemble model that can be used for prediction
    */
   def train(
       input: RDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    new GradientBoosting(boostingStrategy).train(input)
-  }
-
-  /**
-   * Method to train a gradient boosting classification model.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
-   *              For regression, labels are real numbers.
-   * @param boostingStrategy Configuration options for the boosting algorithm.
-   * @return WeightedEnsembleModel that can be used for prediction
-   */
-  def trainClassifier(
-      input: RDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    val algo = boostingStrategy.algo
-    require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
-    new GradientBoosting(boostingStrategy).train(input)
-  }
-
-  /**
-   * Method to train a gradient boosting regression model.
-   *
-   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
-   *              For regression, labels are real numbers.
-   * @param boostingStrategy Configuration options for the boosting algorithm.
-   * @return WeightedEnsembleModel that can be used for prediction
-   */
-  def trainRegressor(
-      input: RDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    val algo = boostingStrategy.algo
-    require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
-    new GradientBoosting(boostingStrategy).train(input)
+      boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
+    new GradientBoostedTrees(boostingStrategy).train(input)
   }

   /**
-   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
    */
   def train(
-      input: JavaRDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    train(input.rdd, boostingStrategy)
-  }
-
-  /**
-   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
-   */
-  def trainClassifier(
       input: JavaRDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    trainClassifier(input.rdd, boostingStrategy)
-  }
-
-  /**
-   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
-   */
-  def trainRegressor(
-      input: JavaRDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
-    trainRegressor(input.rdd, boostingStrategy)
+      boostingStrategy: BoostingStrategy): TreeEnsembleModel = {
+    train(input.rdd, boostingStrategy)
   }

   /**
    * Internal method for performing regression using trees as base learners.
    * @param input training dataset
    * @param boostingStrategy boosting parameters
-   * @return
+   * @return a tree ensemble model that can be used for prediction
    */
   private def boost(
       input: RDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+      boostingStrategy: BoostingStrategy): TreeEnsembleModel = {

     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")

+    boostingStrategy.assertValid()
+
     // Initialize gradient boosting parameters
     val numIterations = boostingStrategy.numIterations
     val baseLearners = new Array[DecisionTreeModel](numIterations)
     val baseLearnerWeights = new Array[Double](numIterations)
     val loss = boostingStrategy.loss
     val learningRate = boostingStrategy.learningRate
-    val strategy = boostingStrategy.weakLearnerParams
+    val ensembleStrategy = boostingStrategy.treeStrategy.copy
+    ensembleStrategy.algo = Regression
+    ensembleStrategy.impurity = impurity.Variance
+    ensembleStrategy.assertValid()

     // Cache input
     if (input.getStorageLevel == StorageLevel.NONE) {
@@ -200,11 +137,10 @@ object GradientBoosting extends Logging {

     // Initialize tree
     timer.start("building tree 0")
-    val firstTreeModel = new DecisionTree(strategy).train(data)
+    val firstTreeModel = new DecisionTree(ensembleStrategy).train(data)
     baseLearners(0) = firstTreeModel
     baseLearnerWeights(0) = 1.0
-    val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
-      Sum)
+    val startingModel = new TreeEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, Sum)
     logDebug("error of gbt = " + loss.computeError(startingModel, input))
     // Note: A model of type regression is used since we require raw prediction
     timer.stop("building tree 0")
@@ -219,7 +155,7 @@ object GradientBoosting extends Logging {
       logDebug("###################################################")
       logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")
-      val model = new DecisionTree(strategy).train(data)
+      val model = new DecisionTree(ensembleStrategy).train(data)
       timer.stop(s"building tree $m")
       // Create partial model
       baseLearners(m) = model
@@ -228,7 +164,7 @@ object GradientBoosting extends Logging {
       // However, the behavior should be reasonable, though not optimal.
       baseLearnerWeights(m) = learningRate
       // Note: A model of type regression is used since we require raw prediction
-      val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
+      val partialModel = new TreeEnsembleModel(baseLearners.slice(0, m + 1),
         baseLearnerWeights.slice(0, m + 1), Regression, Sum)
       logDebug("error of gbt = " + loss.computeError(partialModel, input))
       // Update data with pseudo-residuals
@@ -242,8 +178,6 @@ object GradientBoosting extends Logging {
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")

-    new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
-
+    new TreeEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.treeStrategy.algo, Sum)
   }
-
 }
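
For context, here is a minimal usage sketch of the renamed entry point. It is not part of this commit: the `BoostingStrategy.defaultParams` and `MLUtils.loadLibSVMFile` helpers are assumed to be available in this version of MLlib, and the dataset path, parameter values, and example object name are illustrative only.

// Sketch only, not from the commit above: shows how the renamed
// GradientBoostedTrees.train API might be called for regression.
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils

object GradientBoostedTreesExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "GradientBoostedTreesExample")

    // Load training data in LIBSVM format (path is illustrative).
    val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")

    // Start from assumed default regression parameters and override a few values.
    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 10
    boostingStrategy.treeStrategy.maxDepth = 3

    // Train with the renamed entry point; the result is a TreeEnsembleModel.
    val model = GradientBoostedTrees.train(data, boostingStrategy)

    // Evaluate on the training set with mean squared error.
    val trainMSE = data.map { point =>
      val err = model.predict(point.features) - point.label
      err * err
    }.mean()
    println(s"Training Mean Squared Error = $trainMSE")

    sc.stop()
  }
}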