Commit 2a55cb4

MechCoder authored and jkbradley committed
[SPARK-5972] [MLlib] Cache residuals and gradient in GBT during training and validation
The previous PR #4906 helped to extract the learning curve giving the error for each iteration. This continues the work by refactoring some code and extending the same logic during training and validation.

Author: MechCoder <[email protected]>

Closes #5330 from MechCoder/spark-5972 and squashes the following commits:

0b5d659 [MechCoder] minor
32d409d [MechCoder] EvaluateEachIteration and training cache should follow different paths
d542bb0 [MechCoder] Remove unused imports and docs
58f4932 [MechCoder] Remove unpersist
70d3b4c [MechCoder] Broadcast for each tree
5869533 [MechCoder] Access broadcasted values locally and other minor changes
923dbf6 [MechCoder] [SPARK-5972] Cache residuals and gradient in GBT during training and validation
1 parent 3a205bb commit 2a55cb4

6 files changed (+105, -53 lines)

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 29 additions & 13 deletions
@@ -157,7 +157,6 @@ object GradientBoostedTrees extends Logging {
       validationInput: RDD[LabeledPoint],
       boostingStrategy: BoostingStrategy,
       validate: Boolean): GradientBoostedTreesModel = {
-
     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")
@@ -192,20 +191,29 @@ object GradientBoostedTrees extends Logging {
     // Initialize tree
     timer.start("building tree 0")
     val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+    val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
-    baseLearnerWeights(0) = 1.0
-    val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
-    logDebug("error of gbt = " + loss.computeError(startingModel, input))
+    baseLearnerWeights(0) = firstTreeWeight
+    val startingModel = new GradientBoostedTreesModel(
+      Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
+
+    var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+    logDebug("error of gbt = " + predError.values.mean())
 
     // Note: A model of type regression is used since we require raw prediction
     timer.stop("building tree 0")
 
-    var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
+    var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
     var bestM = 1
 
-    // psuedo-residual for second iteration
-    data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
-      point.features))
+    // pseudo-residual for second iteration
+    data = predError.zip(input).map { case ((pred, _), point) =>
+      LabeledPoint(-loss.gradient(pred, point.label), point.features)
+    }
+
     var m = 1
     while (m < numIterations) {
       timer.start(s"building tree $m")
@@ -222,15 +230,22 @@ object GradientBoostedTrees extends Logging {
       baseLearnerWeights(m) = learningRate
       // Note: A model of type regression is used since we require raw prediction
       val partialModel = new GradientBoostedTreesModel(
-        Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
-      logDebug("error of gbt = " + loss.computeError(partialModel, input))
+        Regression, baseLearners.slice(0, m + 1),
+        baseLearnerWeights.slice(0, m + 1))
+
+      predError = GradientBoostedTreesModel.updatePredictionError(
+        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+      logDebug("error of gbt = " + predError.values.mean())
 
       if (validate) {
         // Stop training early if
         // 1. Reduction in error is less than the validationTol or
         // 2. If the error increases, that is if the model is overfit.
         // We want the model returned corresponding to the best validation error.
-        val currentValidateError = loss.computeError(partialModel, validationInput)
+
+        validatePredError = GradientBoostedTreesModel.updatePredictionError(
+          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+        val currentValidateError = validatePredError.values.mean()
         if (bestValidateError - currentValidateError < validationTol) {
           return new GradientBoostedTreesModel(
             boostingStrategy.treeStrategy.algo,
@@ -242,8 +257,9 @@ object GradientBoostedTrees extends Logging {
         }
       }
       // Update data with pseudo-residuals
-      data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
-        point.features))
+      data = predError.zip(input).map { case ((pred, _), point) =>
+        LabeledPoint(-loss.gradient(pred, point.label), point.features)
+      }
       m += 1
     }
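The core of the change above is that the per-point (prediction, error) pairs are cached and carried forward between boosting iterations, so the growing ensemble is never re-evaluated from scratch when forming the pseudo-residuals. Below is a minimal self-contained sketch of that pattern, using plain Scala collections in place of RDDs; the toy trees, weights, and data are illustrative, not from the patch.

object IncrementalGbtSketch {
  // A "tree" here is just a scoring function (illustrative stand-in for DecisionTreeModel).
  type Tree = Array[Double] => Double

  // Squared-error pieces, matching the pointwise Loss API introduced in this commit.
  def gradient(prediction: Double, label: Double): Double = 2.0 * (prediction - label)
  def computeError(prediction: Double, label: Double): Double = {
    val err = prediction - label
    err * err
  }

  def main(args: Array[String]): Unit = {
    val features = Array(Array(1.0), Array(2.0), Array(3.0))
    val labels = Array(1.5, 2.5, 3.5)
    val trees: Array[Tree] = Array(x => x.head, _ => 0.5) // two toy "trees"
    val weights = Array(1.0, 0.1)

    // Initial cache: one (prediction, error) pair per point, from tree 0 only.
    var predError = features.zip(labels).map { case (x, y) =>
      val pred = weights(0) * trees(0)(x)
      (pred, computeError(pred, y))
    }

    // One boosting step: fold tree 1 into the cache without re-scoring tree 0,
    // then derive the pseudo-residuals (negative gradients) for the next fit.
    predError = predError.zip(features.zip(labels)).map { case ((pred, _), (x, y)) =>
      val newPred = pred + weights(1) * trees(1)(x)
      (newPred, computeError(newPred, y))
    }
    val pseudoResiduals = predError.zip(labels).map { case ((pred, _), y) =>
      -gradient(pred, y)
    }

    println(s"mean error = ${predError.map(_._2).sum / predError.length}")
    println(s"pseudo-residuals = ${pseudoResiduals.mkString(", ")}")
  }
}

In the patch this same shape runs twice: once for input/predError during training and once for validationInput/validatePredError during early stopping.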
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala

Lines changed: 4 additions & 6 deletions
@@ -37,14 +37,12 @@ object AbsoluteError extends Loss {
   * Method to calculate the gradients for the gradient boosting calculation for least
   * absolute error calculation.
   * The gradient with respect to F(x) is: sign(F(x) - y)
-  * @param model Ensemble model
-  * @param point Instance of the training dataset
+  * @param prediction Predicted label.
+  * @param label True label.
   * @return Loss gradient
   */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+  override def gradient(prediction: Double, label: Double): Double = {
+    if (label - prediction < 0) 1.0 else -1.0
   }
 
   override def computeError(prediction: Double, label: Double): Double = {

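For absolute error, L(y, F) = |y - F| and the gradient with respect to F is sign(F - y), so only the raw prediction and the label are needed. A quick illustrative check of the new pointwise form (script-style Scala, not from the patch):

def gradient(prediction: Double, label: Double): Double =
  if (label - prediction < 0) 1.0 else -1.0

assert(gradient(5.0, 2.0) == 1.0)  // prediction above the label: sign(F - y) = +1
assert(gradient(2.0, 5.0) == -1.0) // prediction below the label: sign(F - y) = -1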
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 4 additions & 7 deletions
@@ -39,15 +39,12 @@ object LogLoss extends Loss {
   * Method to calculate the loss gradients for the gradient boosting calculation for binary
   * classification
   * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
-  * @param model Ensemble model
-  * @param point Instance of the training dataset
+  * @param prediction Predicted label.
+  * @param label True label.
   * @return Loss gradient
   */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    val prediction = model.predict(point.features)
-    - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
+  override def gradient(prediction: Double, label: Double): Double = {
+    - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
   }
 
   override def computeError(prediction: Double, label: Double): Double = {

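With labels encoded as -1/+1, the log-loss gradient -4y / (1 + exp(2yF(x))) is likewise a pure function of the raw score and the label. An illustrative spot check (script-style Scala, not from the patch):

def gradient(prediction: Double, label: Double): Double =
  -4.0 * label / (1.0 + math.exp(2.0 * label * prediction))

// At a raw score of 0: -4 * 1 / (1 + e^0) = -2 for a positive example, +2 for a
// negative one; the magnitude approaches 4 as a point is more badly misclassified.
println(gradient(0.0, 1.0))  // -2.0
println(gradient(0.0, -1.0)) //  2.0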
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

Lines changed: 3 additions & 5 deletions
@@ -31,13 +31,11 @@ trait Loss extends Serializable {
 
   /**
    * Method to calculate the gradients for the gradient boosting calculation.
-   * @param model Model of the weak learner.
-   * @param point Instance of the training dataset.
+   * @param prediction Predicted feature
+   * @param label true label.
    * @return Loss gradient.
    */
-  def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double
+  def gradient(prediction: Double, label: Double): Double
 
   /**
    * Method to calculate error of the base learner for the gradient boosting calculation.

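Since gradient is now a pure function of (prediction, label), a custom loss no longer needs any reference to the ensemble. A hedged sketch of a user-defined loss, assuming the trait's abstract members are exactly the two pointwise methods visible in this diff (the Huber-style loss itself is illustrative, not part of the patch):

import org.apache.spark.mllib.tree.loss.Loss

// Illustrative Huber-style loss, scaled to match SquaredError's 2 * (F - y) convention.
object HuberLikeLoss extends Loss {
  private val delta = 1.0

  override def gradient(prediction: Double, label: Double): Double = {
    val diff = prediction - label
    if (math.abs(diff) <= delta) 2.0 * diff
    else 2.0 * delta * math.signum(diff)
  }

  override def computeError(prediction: Double, label: Double): Double = {
    val diff = math.abs(prediction - label)
    if (diff <= delta) diff * diff
    else 2.0 * delta * diff - delta * delta
  }
}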
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala

Lines changed: 4 additions & 6 deletions
@@ -37,14 +37,12 @@ object SquaredError extends Loss {
   * Method to calculate the gradients for the gradient boosting calculation for least
   * squares error calculation.
   * The gradient with respect to F(x) is: - 2 (y - F(x))
-  * @param model Ensemble model
-  * @param point Instance of the training dataset
+  * @param prediction Predicted label.
+  * @param label True label.
   * @return Loss gradient
   */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    2.0 * (model.predict(point.features) - point.label)
+  override def gradient(prediction: Double, label: Double): Double = {
+    2.0 * (prediction - label)
   }
 
   override def computeError(prediction: Double, label: Double): Double = {

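For squared error, L(y, F) = (y - F)^2 gives dL/dF = -2(y - F) = 2(F - y), which is exactly the one-liner above. A quick illustrative check (script-style Scala, not from the patch):

def gradient(prediction: Double, label: Double): Double = 2.0 * (prediction - label)

assert(gradient(3.0, 1.0) == 4.0)  // overshoot by 2  -> gradient +4
assert(gradient(1.0, 3.0) == -4.0) // undershoot by 2 -> gradient -4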
mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

Lines changed: 61 additions & 16 deletions
@@ -130,42 +130,87 @@ class GradientBoostedTreesModel(
 
     val numIterations = trees.length
     val evaluationArray = Array.fill(numIterations)(0.0)
+    val localTreeWeights = treeWeights
+
+    var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError(
+      remappedData, localTreeWeights(0), trees(0), loss)
 
-    var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
-      val pred = treeWeights(0) * trees(0).predict(i.features)
-      val error = loss.computeError(pred, i.label)
-      (pred, error)
-    }
     evaluationArray(0) = predictionAndError.values.mean()
 
-    // Avoid the model being copied across numIterations.
     val broadcastTrees = sc.broadcast(trees)
-    val broadcastWeights = sc.broadcast(treeWeights)
-
     (1 until numIterations).map { nTree =>
       predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
         val currentTree = broadcastTrees.value(nTree)
-        val currentTreeWeight = broadcastWeights.value(nTree)
-        iter.map {
-          case (point, (pred, error)) => {
-            val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
-            val newError = loss.computeError(newPred, point.label)
-            (newPred, newError)
-          }
+        val currentTreeWeight = localTreeWeights(nTree)
+        iter.map { case (point, (pred, error)) =>
+          val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
+          val newError = loss.computeError(newPred, point.label)
+          (newPred, newError)
         }
       }
      evaluationArray(nTree) = predictionAndError.values.mean()
     }
 
     broadcastTrees.unpersist()
-    broadcastWeights.unpersist()
     evaluationArray
   }
 
 }
 
 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
 
+  /**
+   * Compute the initial predictions and errors for a dataset for the first
+   * iteration of gradient boosting.
+   * @param data: training data.
+   * @param initTreeWeight: learning rate assigned to the first tree.
+   * @param initTree: first DecisionTreeModel.
+   * @param loss: evaluation metric.
+   * @return a RDD with each element being a zip of the prediction and error
+   *         corresponding to every sample.
+   */
+  def computeInitialPredictionAndError(
+      data: RDD[LabeledPoint],
+      initTreeWeight: Double,
+      initTree: DecisionTreeModel,
+      loss: Loss): RDD[(Double, Double)] = {
+    data.map { lp =>
+      val pred = initTreeWeight * initTree.predict(lp.features)
+      val error = loss.computeError(pred, lp.label)
+      (pred, error)
+    }
+  }
+
+  /**
+   * Update a zipped predictionError RDD
+   * (as obtained with computeInitialPredictionAndError)
+   * @param data: training data.
+   * @param predictionAndError: predictionError RDD
+   * @param treeWeight: Learning rate.
+   * @param tree: Tree using which the prediction and error should be updated.
+   * @param loss: evaluation metric.
+   * @return a RDD with each element being a zip of the prediction and error
+   *         corresponding to each sample.
+   */
+  def updatePredictionError(
+      data: RDD[LabeledPoint],
+      predictionAndError: RDD[(Double, Double)],
+      treeWeight: Double,
+      tree: DecisionTreeModel,
+      loss: Loss): RDD[(Double, Double)] = {
+
+    val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
+      iter.map {
+        case (lp, (pred, error)) => {
+          val newPred = pred + tree.predict(lp.features) * treeWeight
+          val newError = loss.computeError(newPred, lp.label)
+          (newPred, newError)
+        }
+      }
+    }
+    newPredError
+  }
+
   override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
     val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
     val classNameV1_0 = SaveLoadV1_0.thisClassName

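The two new helpers make the one-pass-per-tree pattern reusable outside evaluateEachIteration. A hedged sketch of a caller building a training-error curve with them; the SparkContext setup, toy data, and three-iteration strategy are illustrative, while the helper names and signatures are the ones added in this diff:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.loss.SquaredError
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

object LearningCurveSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("gbt-curve").setMaster("local[2]"))
    // Toy regression data (illustrative only).
    val data = sc.parallelize(
      Seq.tabulate(20)(i => LabeledPoint(i.toDouble, Vectors.dense(i.toDouble))))

    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 3
    val model = GradientBoostedTrees.train(data, boostingStrategy)

    // Seed the (prediction, error) cache from tree 0, then fold in one tree per
    // pass, never re-scoring earlier trees: the same pattern evaluateEachIteration uses.
    val trees = model.trees
    val weights = model.treeWeights
    var predError = GradientBoostedTreesModel.computeInitialPredictionAndError(
      data, weights(0), trees(0), SquaredError)

    val curve = new Array[Double](trees.length)
    curve(0) = predError.values.mean()
    for (m <- 1 until trees.length) {
      predError = GradientBoostedTreesModel.updatePredictionError(
        data, predError, weights(m), trees(m), SquaredError)
      curve(m) = predError.values.mean()
    }
    println("training error per iteration: " + curve.mkString(", "))
    sc.stop()
  }
}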