@@ -157,7 +157,7 @@ object GradientBoostedTrees extends Logging {
157157 validationInput : RDD [LabeledPoint ],
158158 boostingStrategy : BoostingStrategy ,
159159 validate : Boolean ): GradientBoostedTreesModel = {
160-
160+ val sc = input.sparkContext
161161 val timer = new TimeTracker ()
162162 timer.start(" total" )
163163 timer.start(" init" )
@@ -166,8 +166,8 @@ object GradientBoostedTrees extends Logging {
166166
167167 // Initialize gradient boosting parameters
168168 val numIterations = boostingStrategy.numIterations
169- val baseLearners = new Array [DecisionTreeModel ](numIterations)
170- val baseLearnerWeights = new Array [Double ](numIterations)
169+ val baseLearners = sc.broadcast( new Array [DecisionTreeModel ](numIterations) )
170+ val baseLearnerWeights = sc.broadcast( new Array [Double ](numIterations) )
171171 val loss = boostingStrategy.loss
172172 val learningRate = boostingStrategy.learningRate
173173 // Prepare strategy for individual trees, which use regression with variance impurity.
@@ -192,25 +192,26 @@ object GradientBoostedTrees extends Logging {
192192 // Initialize tree
193193 timer.start(" building tree 0" )
194194 val firstTreeModel = new DecisionTree (treeStrategy).run(data)
195- baseLearners(0 ) = firstTreeModel
196- baseLearnerWeights(0 ) = 1.0
195+ val firstTreeWeight = 1.0
196+ baseLearners.value(0 ) = firstTreeModel
197+ baseLearnerWeights.value(0 ) = firstTreeWeight
197198 val startingModel = new GradientBoostedTreesModel (Regression , Array (firstTreeModel), Array (1.0 ))
198199
199200 var predError : RDD [(Double , Double )] = GradientBoostedTreesModel .
200- computeInitialPredictionAndError(input, 1.0 , firstTreeModel, loss)
201+ computeInitialPredictionAndError(input, firstTreeWeight , firstTreeModel, loss)
201202 logDebug(" error of gbt = " + predError.values.mean())
202203
203204 // Note: A model of type regression is used since we require raw prediction
204205 timer.stop(" building tree 0" )
205206
206207 var validatePredError : RDD [(Double , Double )] = GradientBoostedTreesModel .
207- computeInitialPredictionAndError(validationInput, 1.0 , firstTreeModel, loss)
208+ computeInitialPredictionAndError(validationInput, firstTreeWeight , firstTreeModel, loss)
208209 var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
209210 var bestM = 1
210211
211- // psuedo -residual for second iteration
212- data = predError.zip(input).map {
213- case ((pred, _), point) => LabeledPoint (loss.gradient(pred, point.label), point.features)
212+ // pseudo -residual for second iteration
213+ data = predError.zip(input).map { case ((pred, _), point) =>
214+ LabeledPoint (- loss.gradient(pred, point.label), point.features)
214215 }
215216
216217 var m = 1
@@ -222,17 +223,18 @@ object GradientBoostedTrees extends Logging {
222223 val model = new DecisionTree (treeStrategy).run(data)
223224 timer.stop(s " building tree $m" )
224225 // Create partial model
225- baseLearners(m) = model
226+ baseLearners.value (m) = model
226227 // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
227228 // Technically, the weight should be optimized for the particular loss.
228229 // However, the behavior should be reasonable, though not optimal.
229- baseLearnerWeights(m) = learningRate
230+ baseLearnerWeights.value (m) = learningRate
230231 // Note: A model of type regression is used since we require raw prediction
231232 val partialModel = new GradientBoostedTreesModel (
232- Regression , baseLearners.slice(0 , m + 1 ), baseLearnerWeights.slice(0 , m + 1 ))
233+ Regression , baseLearners.value.slice(0 , m + 1 ),
234+ baseLearnerWeights.value.slice(0 , m + 1 ))
233235
234236 predError = GradientBoostedTreesModel .updatePredictionError(
235- input, predError, learningRate, model , loss)
237+ input, predError, m, baseLearnerWeights, baseLearners , loss)
236238 logDebug(" error of gbt = " + predError.values.mean())
237239
238240 if (validate) {
@@ -242,21 +244,21 @@ object GradientBoostedTrees extends Logging {
242244 // We want the model returned corresponding to the best validation error.
243245
244246 validatePredError = GradientBoostedTreesModel .updatePredictionError(
245- validationInput, validatePredError, learningRate, model , loss)
247+ validationInput, validatePredError, m, baseLearnerWeights, baseLearners , loss)
246248 val currentValidateError = validatePredError.values.mean()
247249 if (bestValidateError - currentValidateError < validationTol) {
248250 return new GradientBoostedTreesModel (
249251 boostingStrategy.treeStrategy.algo,
250- baseLearners.slice(0 , bestM),
251- baseLearnerWeights.slice(0 , bestM))
252+ baseLearners.value. slice(0 , bestM),
253+ baseLearnerWeights.value. slice(0 , bestM))
252254 } else if (currentValidateError < bestValidateError) {
253255 bestValidateError = currentValidateError
254256 bestM = m + 1
255257 }
256258 }
257259 // Update data with pseudo-residuals
258- data = predError.zip(input).map {
259- case ((pred, _), point) => LabeledPoint (- loss.gradient(pred, point.label), point.features)
260+ data = predError.zip(input).map { case ((pred, _), point) =>
261+ LabeledPoint (- loss.gradient(pred, point.label), point.features)
260262 }
261263 m += 1
262264 }
@@ -268,11 +270,11 @@ object GradientBoostedTrees extends Logging {
268270 if (validate) {
269271 new GradientBoostedTreesModel (
270272 boostingStrategy.treeStrategy.algo,
271- baseLearners.slice(0 , bestM),
272- baseLearnerWeights.slice(0 , bestM))
273+ baseLearners.value. slice(0 , bestM),
274+ baseLearnerWeights.value. slice(0 , bestM))
273275 } else {
274276 new GradientBoostedTreesModel (
275- boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
277+ boostingStrategy.treeStrategy.algo, baseLearners.value , baseLearnerWeights.value )
276278 }
277279 }
278280
0 commit comments