Skip to content

Commit 923dbf6

Browse files
committed
[SPARK-5972] Cache residuals and gradient in GBT during training and validation
1 parent 424e987 commit 923dbf6

File tree

6 files changed

+94
-48
lines changed

6 files changed

+94
-48
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,24 @@ object GradientBoostedTrees extends Logging {
195195
baseLearners(0) = firstTreeModel
196196
baseLearnerWeights(0) = 1.0
197197
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
198-
logDebug("error of gbt = " + loss.computeError(startingModel, input))
198+
199+
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
200+
computeInitialPredictionAndError(input, 1.0, firstTreeModel, loss)
201+
logDebug("error of gbt = " + predError.values.mean())
199202

200203
// Note: A model of type regression is used since we require raw prediction
201204
timer.stop("building tree 0")
202205

203-
var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
206+
var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
207+
computeInitialPredictionAndError(validationInput, 1.0, firstTreeModel, loss)
208+
var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
204209
var bestM = 1
205210

206211
// pseudo-residual for second iteration
207-
data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
208-
point.features))
212+
data = predError.zip(input).map {
213+
case ((pred, _), point) => LabeledPoint(loss.gradient(pred, point.label), point.features)
214+
}
215+
209216
var m = 1
210217
while (m < numIterations) {
211218
timer.start(s"building tree $m")
@@ -223,14 +230,20 @@ object GradientBoostedTrees extends Logging {
223230
// Note: A model of type regression is used since we require raw prediction
224231
val partialModel = new GradientBoostedTreesModel(
225232
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
226-
logDebug("error of gbt = " + loss.computeError(partialModel, input))
233+
234+
predError = GradientBoostedTreesModel.updatePredictionError(
235+
input, predError, learningRate, model, loss)
236+
logDebug("error of gbt = " + predError.values.mean())
227237

228238
if (validate) {
229239
// Stop training early if
230240
// 1. Reduction in error is less than the validationTol or
231241
// 2. If the error increases, that is if the model is overfit.
232242
// We want the model returned corresponding to the best validation error.
233-
val currentValidateError = loss.computeError(partialModel, validationInput)
243+
244+
validatePredError = GradientBoostedTreesModel.updatePredictionError(
245+
validationInput, validatePredError, learningRate, model, loss)
246+
val currentValidateError = validatePredError.values.mean()
234247
if (bestValidateError - currentValidateError < validationTol) {
235248
return new GradientBoostedTreesModel(
236249
boostingStrategy.treeStrategy.algo,
@@ -242,8 +255,9 @@ object GradientBoostedTrees extends Logging {
242255
}
243256
}
244257
// Update data with pseudo-residuals
245-
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
246-
point.features))
258+
data = predError.zip(input).map {
259+
case ((pred, _), point) => LabeledPoint(-loss.gradient(pred, point.label), point.features)
260+
}
247261
m += 1
248262
}
249263

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ object AbsoluteError extends Loss {
3737
* Method to calculate the gradients for the gradient boosting calculation for least
3838
* absolute error calculation.
3939
* The gradient with respect to F(x) is: sign(F(x) - y)
40-
* @param model Ensemble model
41-
* @param point Instance of the training dataset
40+
* @param prediction Predicted label
41+
* @param label True label.
4242
* @return Loss gradient
4343
*/
44-
override def gradient(
45-
model: TreeEnsembleModel,
46-
point: LabeledPoint): Double = {
47-
if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
44+
override def gradient(prediction: Double, label: Double): Double = {
45+
if (label - prediction < 0) 1.0 else -1.0
4846
}
4947

5048
override def computeError(prediction: Double, label: Double): Double = {

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,15 @@ object LogLoss extends Loss {
3939
* Method to calculate the loss gradients for the gradient boosting calculation for binary
4040
* classification
4141
* The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
42-
* @param model Ensemble model
43-
* @param point Instance of the training dataset
42+
* @param prediction Predicted label
43+
* @param label True label.
4444
* @return Loss gradient
4545
*/
46-
override def gradient(
47-
model: TreeEnsembleModel,
48-
point: LabeledPoint): Double = {
49-
val prediction = model.predict(point.features)
50-
- 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
46+
override def gradient(prediction: Double, label: Double): Double = {
47+
- 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
5148
}
5249

50+
5351
override def computeError(prediction: Double, label: Double): Double = {
5452
val margin = 2.0 * label * prediction
5553
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,11 @@ trait Loss extends Serializable {
3131

3232
/**
3333
* Method to calculate the gradients for the gradient boosting calculation.
34-
* @param model Model of the weak learner.
35-
* @param point Instance of the training dataset.
34+
* @param prediction Predicted label
35+
* @param label true label.
3636
* @return Loss gradient.
3737
*/
38-
def gradient(
39-
model: TreeEnsembleModel,
40-
point: LabeledPoint): Double
38+
def gradient(prediction: Double, label: Double): Double
4139

4240
/**
4341
* Method to calculate error of the base learner for the gradient boosting calculation.

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ object SquaredError extends Loss {
3737
* Method to calculate the gradients for the gradient boosting calculation for least
3838
* squares error calculation.
3939
* The gradient with respect to F(x) is: - 2 (y - F(x))
40-
* @param model Ensemble model
41-
* @param point Instance of the training dataset
40+
* @param prediction Predicted label
41+
* @param label True label.
4242
* @return Loss gradient
4343
*/
44-
override def gradient(
45-
model: TreeEnsembleModel,
46-
point: LabeledPoint): Double = {
47-
2.0 * (model.predict(point.features) - point.label)
44+
override def gradient(prediction: Double, label: Double): Double = {
45+
2.0 * (prediction - label)
4846
}
4947

5048
override def computeError(prediction: Double, label: Double): Double = {

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -131,29 +131,19 @@ class GradientBoostedTreesModel(
131131
val numIterations = trees.length
132132
val evaluationArray = Array.fill(numIterations)(0.0)
133133

134-
var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
135-
val pred = treeWeights(0) * trees(0).predict(i.features)
136-
val error = loss.computeError(pred, i.label)
137-
(pred, error)
138-
}
134+
var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError(
135+
remappedData, treeWeights(0), trees(0), loss)
136+
139137
evaluationArray(0) = predictionAndError.values.mean()
140138

141139
// Avoid the model being copied across numIterations.
142140
val broadcastTrees = sc.broadcast(trees)
143141
val broadcastWeights = sc.broadcast(treeWeights)
144142

145143
(1 until numIterations).map { nTree =>
146-
predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
147-
val currentTree = broadcastTrees.value(nTree)
148-
val currentTreeWeight = broadcastWeights.value(nTree)
149-
iter.map {
150-
case (point, (pred, error)) => {
151-
val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
152-
val newError = loss.computeError(newPred, point.label)
153-
(newPred, newError)
154-
}
155-
}
156-
}
144+
predictionAndError = GradientBoostedTreesModel.updatePredictionError(
145+
remappedData, predictionAndError, broadcastWeights.value(nTree),
146+
broadcastTrees.value(nTree), loss)
157147
evaluationArray(nTree) = predictionAndError.values.mean()
158148
}
159149

@@ -166,6 +156,56 @@ class GradientBoostedTreesModel(
166156

167157
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
168158

159+
/**
160+
* Method to compute initial error and prediction as a RDD for the first
161+
* iteration of gradient boosting.
162+
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
163+
* @param initTreeWeight: learning rate assigned to the first tree.
164+
* @param initTree: first DecisionTreeModel
165+
* @param loss: evaluation metric
166+
* @return a RDD with each element being a zip of the prediction and error
167+
* corresponding to every sample.
168+
*/
169+
def computeInitialPredictionAndError(
170+
data: RDD[LabeledPoint],
171+
initTreeWeight: Double,
172+
initTree: DecisionTreeModel, loss: Loss): RDD[(Double, Double)] = {
173+
data.map { i =>
174+
val pred = initTreeWeight * initTree.predict(i.features)
175+
val error = loss.computeError(pred, i.label)
176+
(pred, error)
177+
}
178+
}
179+
180+
/**
181+
* Method to update a zipped predictionError RDD
182+
* (as obtained with computeInitialPredictionAndError)
183+
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
184+
* @param predictionAndError: predictionError RDD
185+
* @param currentTreeWeight: learning rate.
186+
* @param currentTree: DecisionTreeModel whose weighted prediction is added to the running prediction
187+
* @param loss: evaluation metric
188+
* @return a RDD with each element being a zip of the prediction and error
189+
* corresponding to each sample.
190+
*/
191+
def updatePredictionError(
192+
data: RDD[LabeledPoint],
193+
predictionAndError: RDD[(Double, Double)],
194+
currentTreeWeight: Double,
195+
currentTree: DecisionTreeModel,
196+
loss: Loss): RDD[(Double, Double)] = {
197+
198+
data.zip(predictionAndError).mapPartitions { iter =>
199+
iter.map {
200+
case (point, (pred, error)) => {
201+
val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
202+
val newError = loss.computeError(newPred, point.label)
203+
(newPred, newError)
204+
}
205+
}
206+
}
207+
}
208+
169209
override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
170210
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
171211
val classNameV1_0 = SaveLoadV1_0.thisClassName

0 commit comments

Comments
 (0)