Skip to content

Commit 5869533

Browse files
committed
Access broadcasted values locally and other minor changes
1 parent 923dbf6 commit 5869533

File tree

5 files changed

+53
-47
lines changed

5 files changed

+53
-47
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ object GradientBoostedTrees extends Logging {
157157
validationInput: RDD[LabeledPoint],
158158
boostingStrategy: BoostingStrategy,
159159
validate: Boolean): GradientBoostedTreesModel = {
160-
160+
val sc = input.sparkContext
161161
val timer = new TimeTracker()
162162
timer.start("total")
163163
timer.start("init")
@@ -166,8 +166,8 @@ object GradientBoostedTrees extends Logging {
166166

167167
// Initialize gradient boosting parameters
168168
val numIterations = boostingStrategy.numIterations
169-
val baseLearners = new Array[DecisionTreeModel](numIterations)
170-
val baseLearnerWeights = new Array[Double](numIterations)
169+
val baseLearners = sc.broadcast(new Array[DecisionTreeModel](numIterations))
170+
val baseLearnerWeights = sc.broadcast(new Array[Double](numIterations))
171171
val loss = boostingStrategy.loss
172172
val learningRate = boostingStrategy.learningRate
173173
// Prepare strategy for individual trees, which use regression with variance impurity.
@@ -192,25 +192,26 @@ object GradientBoostedTrees extends Logging {
192192
// Initialize tree
193193
timer.start("building tree 0")
194194
val firstTreeModel = new DecisionTree(treeStrategy).run(data)
195-
baseLearners(0) = firstTreeModel
196-
baseLearnerWeights(0) = 1.0
195+
val firstTreeWeight = 1.0
196+
baseLearners.value(0) = firstTreeModel
197+
baseLearnerWeights.value(0) = firstTreeWeight
197198
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
198199

199200
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
200-
computeInitialPredictionAndError(input, 1.0, firstTreeModel, loss)
201+
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
201202
logDebug("error of gbt = " + predError.values.mean())
202203

203204
// Note: A model of type regression is used since we require raw prediction
204205
timer.stop("building tree 0")
205206

206207
var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
207-
computeInitialPredictionAndError(validationInput, 1.0, firstTreeModel, loss)
208+
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
208209
var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
209210
var bestM = 1
210211

211-
// psuedo-residual for second iteration
212-
data = predError.zip(input).map {
213-
case ((pred, _), point) => LabeledPoint(loss.gradient(pred, point.label), point.features)
212+
// pseudo-residual for second iteration
213+
data = predError.zip(input).map { case ((pred, _), point) =>
214+
LabeledPoint(-loss.gradient(pred, point.label), point.features)
214215
}
215216

216217
var m = 1
@@ -222,17 +223,18 @@ object GradientBoostedTrees extends Logging {
222223
val model = new DecisionTree(treeStrategy).run(data)
223224
timer.stop(s"building tree $m")
224225
// Create partial model
225-
baseLearners(m) = model
226+
baseLearners.value(m) = model
226227
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
227228
// Technically, the weight should be optimized for the particular loss.
228229
// However, the behavior should be reasonable, though not optimal.
229-
baseLearnerWeights(m) = learningRate
230+
baseLearnerWeights.value(m) = learningRate
230231
// Note: A model of type regression is used since we require raw prediction
231232
val partialModel = new GradientBoostedTreesModel(
232-
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
233+
Regression, baseLearners.value.slice(0, m + 1),
234+
baseLearnerWeights.value.slice(0, m + 1))
233235

234236
predError = GradientBoostedTreesModel.updatePredictionError(
235-
input, predError, learningRate, model, loss)
237+
input, predError, m, baseLearnerWeights, baseLearners, loss)
236238
logDebug("error of gbt = " + predError.values.mean())
237239

238240
if (validate) {
@@ -242,21 +244,21 @@ object GradientBoostedTrees extends Logging {
242244
// We want the model returned corresponding to the best validation error.
243245

244246
validatePredError = GradientBoostedTreesModel.updatePredictionError(
245-
validationInput, validatePredError, learningRate, model, loss)
247+
validationInput, validatePredError, m, baseLearnerWeights, baseLearners, loss)
246248
val currentValidateError = validatePredError.values.mean()
247249
if (bestValidateError - currentValidateError < validationTol) {
248250
return new GradientBoostedTreesModel(
249251
boostingStrategy.treeStrategy.algo,
250-
baseLearners.slice(0, bestM),
251-
baseLearnerWeights.slice(0, bestM))
252+
baseLearners.value.slice(0, bestM),
253+
baseLearnerWeights.value.slice(0, bestM))
252254
} else if (currentValidateError < bestValidateError) {
253255
bestValidateError = currentValidateError
254256
bestM = m + 1
255257
}
256258
}
257259
// Update data with pseudo-residuals
258-
data = predError.zip(input).map {
259-
case ((pred, _), point) => LabeledPoint(-loss.gradient(pred, point.label), point.features)
260+
data = predError.zip(input).map { case ((pred, _), point) =>
261+
LabeledPoint(-loss.gradient(pred, point.label), point.features)
260262
}
261263
m += 1
262264
}
@@ -268,11 +270,11 @@ object GradientBoostedTrees extends Logging {
268270
if (validate) {
269271
new GradientBoostedTreesModel(
270272
boostingStrategy.treeStrategy.algo,
271-
baseLearners.slice(0, bestM),
272-
baseLearnerWeights.slice(0, bestM))
273+
baseLearners.value.slice(0, bestM),
274+
baseLearnerWeights.value.slice(0, bestM))
273275
} else {
274276
new GradientBoostedTreesModel(
275-
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
277+
boostingStrategy.treeStrategy.algo, baseLearners.value, baseLearnerWeights.value)
276278
}
277279
}
278280

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ object AbsoluteError extends Loss {
3737
* Method to calculate the gradients for the gradient boosting calculation for least
3838
* absolute error calculation.
3939
* The gradient with respect to F(x) is: sign(F(x) - y)
40-
* @param prediction Predicted point
40+
* @param prediction Predicted label.
4141
* @param label True label.
4242
* @return Loss gradient
4343
*/

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,14 @@ object LogLoss extends Loss {
3939
* Method to calculate the loss gradients for the gradient boosting calculation for binary
4040
* classification
4141
* The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
42-
* @param prediction Predicted point
42+
* @param prediction Predicted label.
4343
* @param label True label.
4444
* @return Loss gradient
4545
*/
4646
override def gradient(prediction: Double, label: Double): Double = {
4747
- 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
4848
}
4949

50-
5150
override def computeError(prediction: Double, label: Double): Double = {
5251
val margin = 2.0 * label * prediction
5352
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ object SquaredError extends Loss {
3737
* Method to calculate the gradients for the gradient boosting calculation for least
3838
* squares error calculation.
3939
* The gradient with respect to F(x) is: - 2 (y - F(x))
40-
* @param prediction Predicted point
40+
* @param prediction Predicted label.
4141
* @param label True label.
4242
* @return Loss gradient
4343
*/

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.json4s.jackson.JsonMethods._
2727
import org.apache.spark.{Logging, SparkContext}
2828
import org.apache.spark.annotation.Experimental
2929
import org.apache.spark.api.java.JavaRDD
30+
import org.apache.spark.broadcast.Broadcast
3031
import org.apache.spark.mllib.linalg.Vector
3132
import org.apache.spark.mllib.regression.LabeledPoint
3233
import org.apache.spark.mllib.tree.configuration.Algo
@@ -142,8 +143,7 @@ class GradientBoostedTreesModel(
142143

143144
(1 until numIterations).map { nTree =>
144145
predictionAndError = GradientBoostedTreesModel.updatePredictionError(
145-
remappedData, predictionAndError, broadcastWeights.value(nTree),
146-
broadcastTrees.value(nTree), loss)
146+
remappedData, predictionAndError, nTree, broadcastWeights, broadcastTrees, loss)
147147
evaluationArray(nTree) = predictionAndError.values.mean()
148148
}
149149

@@ -157,49 +157,54 @@ class GradientBoostedTreesModel(
157157
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
158158

159159
/**
160-
* Method to compute initial error and prediction as a RDD for the first
160+
* Compute the initial predictions and errors for a dataset for the first
161161
* iteration of gradient boosting.
162-
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
162+
* @param Training data.
163163
* @param initTreeWeight: learning rate assigned to the first tree.
164-
* @param initTree: first DecisionTreeModel
165-
* @param loss: evaluation metric
164+
* @param initTree: first DecisionTreeModel.
165+
* @param loss: evaluation metric.
166166
* @return a RDD with each element being a zip of the prediction and error
167167
* corresponding to every sample.
168168
*/
169169
def computeInitialPredictionAndError(
170170
data: RDD[LabeledPoint],
171171
initTreeWeight: Double,
172-
initTree: DecisionTreeModel, loss: Loss): RDD[(Double, Double)] = {
173-
data.map { i =>
174-
val pred = initTreeWeight * initTree.predict(i.features)
175-
val error = loss.computeError(pred, i.label)
172+
initTree: DecisionTreeModel,
173+
loss: Loss): RDD[(Double, Double)] = {
174+
data.map { lp =>
175+
val pred = initTreeWeight * initTree.predict(lp.features)
176+
val error = loss.computeError(pred, lp.label)
176177
(pred, error)
177178
}
178179
}
179180

180181
/**
181-
* Method to update a zipped predictionError RDD
182+
* Update a zipped predictionError RDD
182183
* (as obtained with computeInitialPredictionAndError)
183-
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
184+
* @param training data.
184185
* @param predictionAndError: predictionError RDD
185-
* @param currentTreeWeight: learning rate.
186-
* @param currentTree: first DecisionTree
187-
* @param loss: evaluation metric
186+
* @param nTree: tree index.
187+
* @param TreeWeights: Broadcasted learning rates.
188+
* @param Trees: Broadcasted trees.
189+
* @param loss: evaluation metric.
188190
* @return a RDD with each element being a zip of the prediction and error
189-
* corresponing to each sample.
191+
* corresponding to each sample.
190192
*/
191193
def updatePredictionError(
192194
data: RDD[LabeledPoint],
193195
predictionAndError: RDD[(Double, Double)],
194-
currentTreeWeight: Double,
195-
currentTree: DecisionTreeModel,
196+
nTree: Int,
197+
TreeWeights: Broadcast[Array[Double]],
198+
Trees: Broadcast[Array[DecisionTreeModel]],
196199
loss: Loss): RDD[(Double, Double)] = {
197200

198201
data.zip(predictionAndError).mapPartitions { iter =>
202+
val currentTreeWeight = TreeWeights.value(nTree)
203+
val currentTree = Trees.value(nTree)
199204
iter.map {
200-
case (point, (pred, error)) => {
201-
val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
202-
val newError = loss.computeError(newPred, point.label)
205+
case (lp, (pred, error)) => {
206+
val newPred = pred + currentTree.predict(lp.features) * currentTreeWeight
207+
val newError = loss.computeError(newPred, lp.label)
203208
(newPred, newError)
204209
}
205210
}

0 commit comments

Comments
 (0)