Skip to content

Commit 70d3b4c

Browse files
committed
Broadcast for each tree
1 parent 5869533 commit 70d3b4c

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ object GradientBoostedTrees extends Logging {
157157
validationInput: RDD[LabeledPoint],
158158
boostingStrategy: BoostingStrategy,
159159
validate: Boolean): GradientBoostedTreesModel = {
160-
val sc = input.sparkContext
161160
val timer = new TimeTracker()
162161
timer.start("total")
163162
timer.start("init")
@@ -166,8 +165,8 @@ object GradientBoostedTrees extends Logging {
166165

167166
// Initialize gradient boosting parameters
168167
val numIterations = boostingStrategy.numIterations
169-
val baseLearners = sc.broadcast(new Array[DecisionTreeModel](numIterations))
170-
val baseLearnerWeights = sc.broadcast(new Array[Double](numIterations))
168+
val baseLearners = new Array[DecisionTreeModel](numIterations)
169+
val baseLearnerWeights = new Array[Double](numIterations)
171170
val loss = boostingStrategy.loss
172171
val learningRate = boostingStrategy.learningRate
173172
// Prepare strategy for individual trees, which use regression with variance impurity.
@@ -193,9 +192,10 @@ object GradientBoostedTrees extends Logging {
193192
timer.start("building tree 0")
194193
val firstTreeModel = new DecisionTree(treeStrategy).run(data)
195194
val firstTreeWeight = 1.0
196-
baseLearners.value(0) = firstTreeModel
197-
baseLearnerWeights.value(0) = firstTreeWeight
198-
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
195+
baseLearners(0) = firstTreeModel
196+
baseLearnerWeights(0) = firstTreeWeight
197+
val startingModel = new GradientBoostedTreesModel(
198+
Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
199199

200200
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
201201
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
@@ -223,18 +223,18 @@ object GradientBoostedTrees extends Logging {
223223
val model = new DecisionTree(treeStrategy).run(data)
224224
timer.stop(s"building tree $m")
225225
// Create partial model
226-
baseLearners.value(m) = model
226+
baseLearners(m) = model
227227
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
228228
// Technically, the weight should be optimized for the particular loss.
229229
// However, the behavior should be reasonable, though not optimal.
230-
baseLearnerWeights.value(m) = learningRate
230+
baseLearnerWeights(m) = learningRate
231231
// Note: A model of type regression is used since we require raw prediction
232232
val partialModel = new GradientBoostedTreesModel(
233-
Regression, baseLearners.value.slice(0, m + 1),
234-
baseLearnerWeights.value.slice(0, m + 1))
233+
Regression, baseLearners.slice(0, m + 1),
234+
baseLearnerWeights.slice(0, m + 1))
235235

236236
predError = GradientBoostedTreesModel.updatePredictionError(
237-
input, predError, m, baseLearnerWeights, baseLearners, loss)
237+
input, predError, baseLearnerWeights(m), baseLearners(m), loss)
238238
logDebug("error of gbt = " + predError.values.mean())
239239

240240
if (validate) {
@@ -244,13 +244,13 @@ object GradientBoostedTrees extends Logging {
244244
// We want the model returned corresponding to the best validation error.
245245

246246
validatePredError = GradientBoostedTreesModel.updatePredictionError(
247-
validationInput, validatePredError, m, baseLearnerWeights, baseLearners, loss)
247+
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
248248
val currentValidateError = validatePredError.values.mean()
249249
if (bestValidateError - currentValidateError < validationTol) {
250250
return new GradientBoostedTreesModel(
251251
boostingStrategy.treeStrategy.algo,
252-
baseLearners.value.slice(0, bestM),
253-
baseLearnerWeights.value.slice(0, bestM))
252+
baseLearners.slice(0, bestM),
253+
baseLearnerWeights.slice(0, bestM))
254254
} else if (currentValidateError < bestValidateError) {
255255
bestValidateError = currentValidateError
256256
bestM = m + 1
@@ -270,11 +270,11 @@ object GradientBoostedTrees extends Logging {
270270
if (validate) {
271271
new GradientBoostedTreesModel(
272272
boostingStrategy.treeStrategy.algo,
273-
baseLearners.value.slice(0, bestM),
274-
baseLearnerWeights.value.slice(0, bestM))
273+
baseLearners.slice(0, bestM),
274+
baseLearnerWeights.slice(0, bestM))
275275
} else {
276276
new GradientBoostedTreesModel(
277-
boostingStrategy.treeStrategy.algo, baseLearners.value, baseLearnerWeights.value)
277+
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
278278
}
279279
}
280280

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -137,18 +137,12 @@ class GradientBoostedTreesModel(
137137

138138
evaluationArray(0) = predictionAndError.values.mean()
139139

140-
// Avoid the model being copied across numIterations.
141-
val broadcastTrees = sc.broadcast(trees)
142-
val broadcastWeights = sc.broadcast(treeWeights)
143-
144140
(1 until numIterations).map { nTree =>
145141
predictionAndError = GradientBoostedTreesModel.updatePredictionError(
146-
remappedData, predictionAndError, nTree, broadcastWeights, broadcastTrees, loss)
142+
remappedData, predictionAndError, treeWeights(nTree), trees(nTree), loss)
147143
evaluationArray(nTree) = predictionAndError.values.mean()
148144
}
149145

150-
broadcastTrees.unpersist()
151-
broadcastWeights.unpersist()
152146
evaluationArray
153147
}
154148

@@ -184,23 +178,26 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
184178
* @param training data.
185179
* @param predictionAndError: predictionError RDD
186180
* @param nTree: tree index.
187-
* @param TreeWeights: Broadcasted learning rates.
188-
* @param Trees: Broadcasted trees.
181+
* @param treeWeight: Learning rate.
182+
* @param tree: Tree using which the prediction and error should be updated.
189183
* @param loss: evaluation metric.
190184
* @return a RDD with each element being a zip of the prediction and error
191185
* corresponding to each sample.
192186
*/
193187
def updatePredictionError(
194188
data: RDD[LabeledPoint],
195189
predictionAndError: RDD[(Double, Double)],
196-
nTree: Int,
197-
TreeWeights: Broadcast[Array[Double]],
198-
Trees: Broadcast[Array[DecisionTreeModel]],
190+
treeWeight: Double,
191+
tree: DecisionTreeModel,
199192
loss: Loss): RDD[(Double, Double)] = {
200193

201-
data.zip(predictionAndError).mapPartitions { iter =>
202-
val currentTreeWeight = TreeWeights.value(nTree)
203-
val currentTree = Trees.value(nTree)
194+
val sc = data.sparkContext
195+
val broadcastedTreeWeight = sc.broadcast(treeWeight)
196+
val broadcastedTree = sc.broadcast(tree)
197+
198+
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
199+
val currentTreeWeight = broadcastedTreeWeight.value
200+
val currentTree = broadcastedTree.value
204201
iter.map {
205202
case (lp, (pred, error)) => {
206203
val newPred = pred + currentTree.predict(lp.features) * currentTreeWeight
@@ -209,6 +206,10 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
209206
}
210207
}
211208
}
209+
210+
broadcastedTreeWeight.unpersist()
211+
broadcastedTree.unpersist()
212+
newPredError
212213
}
213214

214215
override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {

0 commit comments

Comments
 (0)