@@ -157,7 +157,6 @@ object GradientBoostedTrees extends Logging {
157157 validationInput : RDD [LabeledPoint ],
158158 boostingStrategy : BoostingStrategy ,
159159 validate : Boolean ): GradientBoostedTreesModel = {
160- val sc = input.sparkContext
161160 val timer = new TimeTracker ()
162161 timer.start(" total" )
163162 timer.start(" init" )
@@ -166,8 +165,8 @@ object GradientBoostedTrees extends Logging {
166165
167166 // Initialize gradient boosting parameters
168167 val numIterations = boostingStrategy.numIterations
169- val baseLearners = sc.broadcast( new Array [DecisionTreeModel ](numIterations) )
170- val baseLearnerWeights = sc.broadcast( new Array [Double ](numIterations) )
168+ val baseLearners = new Array [DecisionTreeModel ](numIterations)
169+ val baseLearnerWeights = new Array [Double ](numIterations)
171170 val loss = boostingStrategy.loss
172171 val learningRate = boostingStrategy.learningRate
173172 // Prepare strategy for individual trees, which use regression with variance impurity.
@@ -193,9 +192,10 @@ object GradientBoostedTrees extends Logging {
193192 timer.start(" building tree 0" )
194193 val firstTreeModel = new DecisionTree (treeStrategy).run(data)
195194 val firstTreeWeight = 1.0
196- baseLearners.value(0 ) = firstTreeModel
197- baseLearnerWeights.value(0 ) = firstTreeWeight
198- val startingModel = new GradientBoostedTreesModel (Regression , Array (firstTreeModel), Array (1.0 ))
195+ baseLearners(0 ) = firstTreeModel
196+ baseLearnerWeights(0 ) = firstTreeWeight
197+ val startingModel = new GradientBoostedTreesModel (
198+ Regression , Array (firstTreeModel), baseLearnerWeights.slice(0 , 1 ))
199199
200200 var predError : RDD [(Double , Double )] = GradientBoostedTreesModel .
201201 computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
@@ -223,18 +223,18 @@ object GradientBoostedTrees extends Logging {
223223 val model = new DecisionTree (treeStrategy).run(data)
224224 timer.stop(s " building tree $m" )
225225 // Create partial model
226- baseLearners.value (m) = model
226+ baseLearners(m) = model
227227 // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
228228 // Technically, the weight should be optimized for the particular loss.
229229 // However, the behavior should be reasonable, though not optimal.
230- baseLearnerWeights.value (m) = learningRate
230+ baseLearnerWeights(m) = learningRate
231231 // Note: A model of type regression is used since we require raw prediction
232232 val partialModel = new GradientBoostedTreesModel (
233- Regression , baseLearners.value. slice(0 , m + 1 ),
234- baseLearnerWeights.value. slice(0 , m + 1 ))
233+ Regression , baseLearners.slice(0 , m + 1 ),
234+ baseLearnerWeights.slice(0 , m + 1 ))
235235
236236 predError = GradientBoostedTreesModel .updatePredictionError(
237- input, predError, m, baseLearnerWeights, baseLearners, loss)
237+ input, predError, baseLearnerWeights(m), baseLearners(m) , loss)
238238 logDebug(" error of gbt = " + predError.values.mean())
239239
240240 if (validate) {
@@ -244,13 +244,13 @@ object GradientBoostedTrees extends Logging {
244244 // We want the model returned corresponding to the best validation error.
245245
246246 validatePredError = GradientBoostedTreesModel .updatePredictionError(
247- validationInput, validatePredError, m, baseLearnerWeights, baseLearners, loss)
247+ validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m) , loss)
248248 val currentValidateError = validatePredError.values.mean()
249249 if (bestValidateError - currentValidateError < validationTol) {
250250 return new GradientBoostedTreesModel (
251251 boostingStrategy.treeStrategy.algo,
252- baseLearners.value. slice(0 , bestM),
253- baseLearnerWeights.value. slice(0 , bestM))
252+ baseLearners.slice(0 , bestM),
253+ baseLearnerWeights.slice(0 , bestM))
254254 } else if (currentValidateError < bestValidateError) {
255255 bestValidateError = currentValidateError
256256 bestM = m + 1
@@ -270,11 +270,11 @@ object GradientBoostedTrees extends Logging {
270270 if (validate) {
271271 new GradientBoostedTreesModel (
272272 boostingStrategy.treeStrategy.algo,
273- baseLearners.value. slice(0 , bestM),
274- baseLearnerWeights.value. slice(0 , bestM))
273+ baseLearners.slice(0 , bestM),
274+ baseLearnerWeights.slice(0 , bestM))
275275 } else {
276276 new GradientBoostedTreesModel (
277- boostingStrategy.treeStrategy.algo, baseLearners.value , baseLearnerWeights.value )
277+ boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
278278 }
279279 }
280280
0 commit comments