[SPARK-6025] [MLlib] Add helper method evaluateEachIteration to extract learning curve #4906
Changes from 1 commit
GradientBoostedTrees.scala

@@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.configuration.BoostingStrategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.impl.TimeTracker
 import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.loss.Loss
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -52,14 +53,18 @@ import org.apache.spark.storage.StorageLevel
 class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
   extends Serializable with Logging {

+  private val numIterations = boostingStrategy.numIterations
+  private var baseLearners = new Array[DecisionTreeModel](numIterations)
+  private var baseLearnerWeights = new Array[Double](numIterations)
+
   /**
    * Method to train a gradient boosting model
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return a gradient boosted trees model that can be used for prediction
    */
   def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
     val algo = boostingStrategy.treeStrategy.algo
-    algo match {
+    val fitGradientBoostingModel = algo match {
       case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
       case Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
@@ -69,6 +74,42 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
     }
+    baseLearners = fitGradientBoostingModel.trees
+    baseLearnerWeights = fitGradientBoostingModel.treeWeights
+    fitGradientBoostingModel
   }

+  /**
+   * Method to compute error or loss for every iteration of gradient boosting.
+   * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param loss evaluation metric that defaults to boostingStrategy.loss
+   * @return an array with index i having the losses or errors for the ensemble
+   *         containing trees 1 to i + 1
+   */
+  def evaluateEachIteration(
+      data: RDD[LabeledPoint],
+      loss: Loss = boostingStrategy.loss): Array[Double] = {
+
+    val algo = boostingStrategy.treeStrategy.algo
+    val remappedData = algo match {
+      case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+      case _ => data
+    }
+    val initialTree = baseLearners(0)
+    val evaluationArray = Array.fill(numIterations)(0.0)
+
+    // Initial weight is 1.0
+    var predictionRDD = remappedData.map(i => initialTree.predict(i.features))
+    evaluationArray(0) = loss.computeError(remappedData, predictionRDD)
+
+    (1 until numIterations).foreach { nTree =>
+      predictionRDD = (remappedData zip predictionRDD) map { case (point, pred) =>
+        pred + baseLearners(nTree).predict(point.features) * baseLearnerWeights(nTree)
+      }
+      evaluationArray(nTree) = loss.computeError(remappedData, predictionRDD)
+    }
+    evaluationArray
+  }
+
   /**
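For orientation, here is a minimal usage sketch of the API added above. The strategy settings and the synthetic training data are illustrative placeholders, not part of the PR; a live SparkContext `sc` (e.g. in spark-shell) is assumed:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.rdd.RDD

// Toy training data, invented for illustration.
val trainingData: RDD[LabeledPoint] = sc.parallelize(
  (0 until 100).map(x => LabeledPoint(x.toDouble, Vectors.dense(x.toDouble))))

val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.numIterations = 50

val gbt = new GradientBoostedTrees(boostingStrategy)
val model = gbt.run(trainingData)

// Index i holds the loss of the ensemble of trees 1 to i + 1;
// the loss argument defaults to boostingStrategy.loss.
val learningCurve: Array[Double] = gbt.evaluateEachIteration(trainingData)
learningCurve.zipWithIndex.foreach { case (err, i) =>
  println(s"iteration ${i + 1}: loss = $err")
}
```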
AbsoluteError.scala
@@ -61,4 +61,23 @@ object AbsoluteError extends Loss {
       math.abs(err)
     }.mean()
   }
+
+  /**
Member: No need for doc; it will be inherited from the overridden method (here and in the other two loss classes).

Contributor (Author): But the doc for the return is different, no?

Member: I think it's OK for the doc for gradient() and computeError() to be generic as long as the doc for the loss classes describes the specific loss function.

Contributor (Author): OK, so should I remove it?

Member: Yes, please.
+   * Method to calculate loss when the predictions are already known.
+   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
+   * predicted values from previously fit trees.
+   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @param prediction RDD[Double] of predicted labels.
+   * @return Mean absolute error of the model on the data
+   */
+  override def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
+    val errorAcrossSamples = (data zip prediction) map { case (yTrue, yPred) =>
+      val err = yTrue.label - yPred
+      math.abs(err)
+    }
+    errorAcrossSamples.mean()
+  }
 }
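As a quick spot check of the new overload, a hedged sketch: the two-point dataset is invented for illustration, `sc` is an existing SparkContext, and both RDDs use a single partition so that zip's identical-partitioning requirement holds.

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.loss.AbsoluteError

// |1.0 - 2.0| = 1.0 and |3.0 - 5.0| = 2.0, so the mean absolute error is 1.5.
val data = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(0.0)),
  LabeledPoint(3.0, Vectors.dense(1.0))), numSlices = 1)
val predictions = sc.parallelize(Seq(2.0, 5.0), numSlices = 1)
assert(AbsoluteError.computeError(data, predictions) == 1.5)
```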
Loss.scala
@@ -49,4 +49,14 @@ trait Loss extends Serializable {
    */
   def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
Member: Provide a default implementation using the other computeError, and then remove the overridden copies from the child classes.

Contributor (Author): Great, it gives mental satisfaction to remove huge blocks of code :P

Member: : )
+
+  /**
+   * Method to calculate loss when the predictions are already known.
+   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
+   * predicted values from previously fit trees.
+   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @param prediction RDD[Double] of predicted labels.
+   * @return Measure of model error on data
+   */
+  def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double
 }
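One way to realize the reviewer's suggestion is a default implementation in the trait built on a per-point error. The per-point computeError(prediction, label) helper below is an assumption of this sketch, not something this diff defines:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

trait Loss extends Serializable {

  // Assumed per-point helper that each concrete loss would define,
  // e.g. math.abs(label - prediction) for AbsoluteError.
  def computeError(prediction: Double, label: Double): Double

  // Default implementation: with this in the trait, AbsoluteError,
  // SquaredError, and LogLoss no longer need their own overrides.
  def computeError(data: RDD[LabeledPoint], prediction: RDD[Double]): Double = {
    data.zip(prediction).map { case (point, pred) =>
      computeError(pred, point.label)
    }.mean()
  }
}
```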
GradientBoostedTreesSuite.scala
@@ -175,10 +175,12 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
       new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
     val gbtValidate = new GradientBoostedTrees(boostingStrategy)
       .runWithValidation(trainRdd, validateRdd)
-    assert(gbtValidate.numTrees !== numIterations)
+    val numTrees = gbtValidate.numTrees
+    assert(numTrees !== numIterations)

     // Test that it performs better on the validation dataset.
-    val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+    val gbtModel = new GradientBoostedTrees(boostingStrategy)
+    val gbt = gbtModel.run(trainRdd)
     val (errorWithoutValidation, errorWithValidation) = {
       if (algo == Classification) {
         val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -188,6 +190,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { | |
| } | ||
| } | ||
| assert(errorWithValidation <= errorWithoutValidation) | ||
|
|
||
| // Test that results from evaluateEachIteration comply with runWithValidation. | ||
| // Note that convergenceTol is set to 0.0 | ||
| val evaluationArray = gbtModel.evaluateEachIteration(validateRdd) | ||
| assert(evaluationArray.length === numIterations) | ||
| assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) | ||
| var i = 1 | ||
| while (i < numTrees) { | ||
| assert(evaluationArray(i) < evaluationArray(i - 1)) | ||
|
Member: Small issue: < should be <=.
+      i += 1
+    }
     }
   }
 }
Member: This method should be implemented in the model, not in the estimator. There's no need to make a duplicate of the model in the estimator class. (We try to keep estimator classes stateless except for parameter values so that they remain lightweight types.)

This change will require a bit of refactoring, so I'll hold off on more comments until then.
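To make the requested refactoring concrete, a rough sketch: the computation needs only the fitted trees and their weights, both of which GradientBoostedTreesModel already carries, so the method can live on the model. The free-function form below is illustrative only, not the final API:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD

// Same per-iteration logic as the estimator version above, but parameterized
// on the model's trees and treeWeights instead of mutable estimator fields.
// (Classification callers would remap labels to -1/+1 before calling, as in
// the estimator version.)
def evaluateEachIteration(
    data: RDD[LabeledPoint],
    trees: Array[DecisionTreeModel],
    treeWeights: Array[Double],
    loss: Loss): Array[Double] = {
  val evaluation = new Array[Double](trees.length)
  // The first tree always enters the ensemble with weight 1.0.
  var predictions = data.map(p => trees(0).predict(p.features))
  evaluation(0) = loss.computeError(data, predictions)
  // A for comprehension gives each RDD closure a stable copy of the index,
  // which matters if an RDD's lineage is ever recomputed.
  for (i <- 1 until trees.length) {
    predictions = data.zip(predictions).map { case (point, pred) =>
      pred + trees(i).predict(point.features) * treeWeights(i)
    }
    evaluation(i) = loss.computeError(data, predictions)
  }
  evaluation
}
```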