
Commit bc99ac6

Refactor evaluateEachIteration into GradientBoostedTreesModel and make Loss.computeError per-datum
1 parent dbda033 commit bc99ac6

7 files changed: +75 -81 lines changed


mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 1 addition & 42 deletions
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.tree.configuration.BoostingStrategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.impl.TimeTracker
 import org.apache.spark.mllib.tree.impurity.Variance
-import org.apache.spark.mllib.tree.loss.Loss
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -53,18 +52,14 @@ import org.apache.spark.storage.StorageLevel
 class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
   extends Serializable with Logging {
 
-  private val numIterations = boostingStrategy.numIterations
-  private var baseLearners = new Array[DecisionTreeModel](numIterations)
-  private var baseLearnerWeights = new Array[Double](numIterations)
-
   /**
    * Method to train a gradient boosting model
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return a gradient boosted trees model that can be used for prediction
    */
   def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
     val algo = boostingStrategy.treeStrategy.algo
-    val fitGradientBoostingModel = algo match {
+    algo match {
       case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
       case Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
@@ -74,42 +69,6 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
     }
-    baseLearners = fitGradientBoostingModel.trees
-    baseLearnerWeights = fitGradientBoostingModel.treeWeights
-    fitGradientBoostingModel
-  }
-
-  /**
-   * Method to compute error or loss for every iteration of gradient boosting.
-   * @param data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
-   * @param loss: evaluation metric that defaults to boostingStrategy.loss
-   * @return an array with index i having the losses or errors for the ensemble
-   *         containing trees 1 to i + 1
-   */
-  def evaluateEachIteration(
-      data: RDD[LabeledPoint],
-      loss: Loss = boostingStrategy.loss) : Array[Double] = {
-
-    val algo = boostingStrategy.treeStrategy.algo
-    val remappedData = algo match {
-      case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-      case _ => data
-    }
-    val initialTree = baseLearners(0)
-    val evaluationArray = Array.fill(numIterations)(0.0)
-
-    // Initial weight is 1.0
-    var predictionRDD = remappedData.map(i => initialTree.predict(i.features))
-    evaluationArray(0) = loss.computeError(remappedData, predictionRDD)
-
-    (1 until numIterations).map { nTree =>
-      predictionRDD = (remappedData zip predictionRDD) map {
-        case (point, pred) =>
-          pred + baseLearners(nTree).predict(point.features) * baseLearnerWeights(nTree)
-      }
-      evaluationArray(nTree) = loss.computeError(remappedData, predictionRDD)
-    }
-    evaluationArray
   }
 
   /**
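With the cached baseLearners and baseLearnerWeights fields removed, per-iteration evaluation now goes through the returned model rather than the trainer. A minimal sketch of the new call pattern, assuming a boostingStrategy plus trainingData and validationData RDDs are already in scope:

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.loss.SquaredError

    // run() now just returns the fitted model; the trainer no longer keeps
    // its own copy of the trees and weights.
    val model = new GradientBoostedTrees(boostingStrategy).run(trainingData)

    // evaluateEachIteration moved to GradientBoostedTreesModel (see
    // treeEnsembleModels.scala below), and the loss is passed explicitly
    // since the model does not know the boosting strategy's default.
    val perIterationError = model.evaluateEachIteration(validationData, SquaredError)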

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala

Lines changed: 6 additions & 11 deletions
@@ -66,18 +66,13 @@ object AbsoluteError extends Loss {
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @param prediction: RDD[Double] of predicted labels.
-   * @return Mean absolute error of model on data
+   * @param datum: LabeledPoint
+   * @param prediction: Predicted label.
+   * @return Absolute error of model on the given datapoint.
    */
-  override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
-    val errorAcrossSamples = (data zip prediction) map {
-      case (yTrue, yPred) => {
-        val err = yTrue.label - yPred
-        math.abs(err)
-      }
-    }
-    errorAcrossSamples.mean()
+  override def computeError(datum: LabeledPoint, prediction: Double): Double = {
+    val err = datum.label - prediction
+    math.abs(err)
   }
 
 }

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 7 additions & 11 deletions
@@ -71,18 +71,14 @@ object LogLoss extends Loss {
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @param prediction: RDD[Double] of predicted labels.
-   * @return Mean log loss of model on data
+   * @param datum: LabeledPoint
+   * @param prediction: Predicted label.
+   * @return Log loss of model on the given datapoint.
    */
-  override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
-    val errorAcrossSamples = (data zip prediction) map {
-      case (yTrue, yPred) =>
-        val margin = 2.0 * yTrue.label * yPred
-        // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
-        2.0 * MLUtils.log1pExp(-margin)
-    }
-    errorAcrossSamples.mean()
+  override def computeError(datum: LabeledPoint, prediction: Double): Double = {
+    val margin = 2.0 * datum.label * prediction
+    // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+    2.0 * MLUtils.log1pExp(-margin)
   }
 
 }
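The numerical-stability comment deserves a note: evaluating log(1 + exp(x)) directly overflows for large positive x even though the result is approximately x. A helper in the spirit of MLUtils.log1pExp typically branches on the sign; a minimal sketch, not necessarily Spark's exact source:

    // Stable log(1 + exp(x)). For x > 0, factor out x so that exp is only
    // ever called on a non-positive argument; math.log1p keeps precision
    // when exp(-x) or exp(x) is tiny.
    def log1pExp(x: Double): Double = {
      if (x > 0) {
        x + math.log1p(math.exp(-x))
      } else {
        math.log1p(math.exp(x))
      }
    }

Since the margin is negated before the call, LogLoss effectively computes 2.0 * log(1 + exp(-margin)) without overflow for strongly misclassified points.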

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala

Lines changed: 4 additions & 4 deletions
@@ -53,10 +53,10 @@ trait Loss extends Serializable {
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @param prediction: RDD[Double] of predicted labels.
-   * @return Measure of model error on data
+   * @param datum: LabeledPoint
+   * @param prediction: Predicted label.
+   * @return Measure of model error on the datapoint.
    */
-  def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double
+  def computeError(datum: LabeledPoint, prediction: Double): Double
 
 }
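Since computeError is now per-datum, a caller that relied on the removed RDD overload can recover the aggregate itself. A minimal sketch of reproducing the old behavior under the new contract (points and predictions are assumed to be matching RDDs):

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.loss.AbsoluteError
    import org.apache.spark.rdd.RDD

    // The removed RDD-based computeError was just the mean of the per-datum
    // errors, so zipping labels with predictions and averaging restores it.
    def meanError(points: RDD[LabeledPoint], predictions: RDD[Double]): Double = {
      points.zip(predictions).map { case (point, pred) =>
        AbsoluteError.computeError(point, pred)
      }.mean()
    }

Keeping the formula per-datum lets evaluateEachIteration fold predictions and errors in a single pass instead of zipping a fresh prediction RDD for every loss computation.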

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala

Lines changed: 6 additions & 10 deletions
@@ -66,17 +66,13 @@ object SquaredError extends Loss {
    * Method to calculate loss when the predictions are already known.
    * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
    * predicted values from previously fit trees.
-   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @param prediction: RDD[Double] of predicted labels.
-   * @return Mean squared error of model on data
+   * @param datum: LabeledPoint
+   * @param prediction: Predicted label.
+   * @return Squared error of model on the given datapoint.
    */
-  override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
-    val errorAcrossSamples = (data zip prediction) map {
-      case (yTrue, yPred) =>
-        val err = yPred - yTrue.label
-        err * err
-    }
-    errorAcrossSamples.mean()
+  override def computeError(datum: LabeledPoint, prediction: Double): Double = {
+    val err = prediction - datum.label
+    err * err
   }
 
 }

mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala

Lines changed: 49 additions & 0 deletions
@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.tree.loss.Loss
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
@@ -108,6 +110,53 @@ class GradientBoostedTreesModel(
   }
 
   override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
+  /**
+   * Method to compute error or loss for every iteration of gradient boosting.
+   * @param data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param loss: evaluation metric.
+   * @return an array with index i having the losses or errors for the ensemble
+   *         containing trees 1 to i + 1
+   */
+  def evaluateEachIteration(
+      data: RDD[LabeledPoint],
+      loss: Loss): Array[Double] = {
+
+    val sc = data.sparkContext
+    val remappedData = algo match {
+      case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+      case _ => data
+    }
+    val initialTree = trees(0)
+    val numIterations = trees.length
+    val evaluationArray = Array.fill(numIterations)(0.0)
+
+    // Initial weight is 1.0
+    var predictionErrorModel = remappedData.map { i =>
+      val pred = initialTree.predict(i.features)
+      val error = loss.computeError(i, pred)
+      (pred, error)
+    }
+    evaluationArray(0) = predictionErrorModel.values.mean()
+
+    // Avoid the model being copied across numIterations.
+    val broadcastTrees = sc.broadcast(trees)
+    val broadcastWeights = sc.broadcast(treeWeights)
+
+    (1 until numIterations).foreach { nTree =>
+      predictionErrorModel = (remappedData zip predictionErrorModel) map {
+        case (point, (pred, error)) => {
+          val newPred = pred + (
+            broadcastTrees.value(nTree).predict(point.features) * broadcastWeights.value(nTree))
+          val newError = loss.computeError(point, newPred)
+          (newPred, newError)
+        }
+      }
+      evaluationArray(nTree) = predictionErrorModel.values.mean()
+    }
+    evaluationArray
+  }
+
 }
 
 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
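For context, a sketch of how the relocated method might be used to read off a learning curve; model and validationData are assumed from an earlier run, and LogLoss matches a classification boosting strategy:

    import org.apache.spark.mllib.tree.loss.LogLoss

    // evaluationArray(i) is the validation error of the ensemble truncated
    // to trees 1 through i + 1, so the best prefix length is the index of
    // the minimum plus one.
    val evaluationArray = model.evaluateEachIteration(validationData, LogLoss)
    val bestNumIterations = evaluationArray.indexOf(evaluationArray.min) + 1

Broadcasting trees and treeWeights once, rather than closing over the model in every map, keeps the ensemble from being re-shipped with each of the numIterations jobs.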

mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala

Lines changed: 2 additions & 3 deletions
@@ -179,8 +179,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
     assert(numTrees !== numIterations)
 
     // Test that it performs better on the validation dataset.
-    val gbtModel = new GradientBoostedTrees(boostingStrategy)
-    val gbt = gbtModel.run(trainRdd)
+    val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
     val (errorWithoutValidation, errorWithValidation) = {
       if (algo == Classification) {
         val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -193,7 +192,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
 
     // Test that results from evaluateEachIteration comply with runWithValidation.
     // Note that convergenceTol is set to 0.0
-    val evaluationArray = gbtModel.evaluateEachIteration(validateRdd)
+    val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
     assert(evaluationArray.length === numIterations)
     assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
     var i = 1
