[SPARK-6025] [MLlib] Add helper method evaluateEachIteration to extract learning curve #4906
Changes from 2 commits.
**`Loss.scala`** (`@@ -49,4 +49,14 @@ trait Loss extends Serializable {`)

```diff
    */
   def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
```

- **Member** (on the batch `computeError`): Provide a default implementation using the other `computeError`, and then remove the overridden copies from the child classes.
- **Contributor (author):** Great, it gives mental satisfaction to remove huge blocks of code :P
- **Member:** : )
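For reference, a minimal sketch of what that suggestion could look like in the trait. This is illustrative, not the merged code, and a later comment asks to swap the per-datum argument order:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD

trait Loss extends Serializable {

  // Batch error with a default implementation: average the per-datum error
  // over the dataset, so concrete loss classes no longer need to override
  // this method.
  def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
    data.map(point => computeError(point, model.predict(point.features))).mean()
  }

  // Per-datum error: the only error method a concrete loss class must
  // implement (gradient is omitted from this sketch).
  def computeError(datum: LabeledPoint, prediction: Double): Double
}
```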
The new per-datum overload added in this revision:

```diff
+
+  /**
+   * Method to calculate loss when the predictions are already known.
+   * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
+   * predicted values from previously fit trees.
+   * @param datum: LabeledPoint
+   * @param prediction: Predicted label.
+   * @return Measure of model error on datapoint.
+   */
+  def computeError(datum: LabeledPoint, prediction: Double) : Double
 }
```

- **Member** (on `@param datum:`): No colon after the param name (here and elsewhere).
- **Member** (on the signature): Switch the argument order to match the batch `computeError` more closely, and no space before the colon (here and elsewhere).
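To make the effect concrete, here is an illustrative sketch, assuming the trait default shown earlier, of what a loss class such as `SquaredError` could shrink to. The per-datum signature follows this revision of the PR, and the gradient is the usual derivative of the squared error, not necessarily Spark's exact code:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel

object SquaredError extends Loss {

  // Gradient of (prediction - label)^2 with respect to the prediction.
  override def gradient(model: TreeEnsembleModel, point: LabeledPoint): Double = {
    2.0 * (model.predict(point.features) - point.label)
  }

  // Per-datum squared error; the batch computeError is inherited from Loss.
  override def computeError(datum: LabeledPoint, prediction: Double): Double = {
    val err = prediction - datum.label
    err * err
  }
}
```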
**`GradientBoostedTreesModel`** (`@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}`)

Two imports are added for the new method:

```diff
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.tree.loss.Loss
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
```
In the same file (`@@ -108,6 +110,53 @@ class GradientBoostedTreesModel(`), the new `evaluateEachIteration` method:

```diff
   }

   override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion

+  /**
+   * Method to compute error or loss for every iteration of gradient boosting.
+   * @param data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+   * @param loss: evaluation metric.
+   * @return an array with index i having the losses or errors for the ensemble
+   *         containing trees 1 to i + 1
+   */
+  def evaluateEachIteration(
+      data: RDD[LabeledPoint],
+      loss: Loss) : Array[Double] = {
```

- **Member** (on the `@return` doc): Minor: use 0-based indexing in the doc: "containing trees 0 to i".
```diff
+
+    val sc = data.sparkContext
+    val remappedData = algo match {
+      case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+      case _ => data
+    }
+    val initialTree = trees(0)
+    val numIterations = trees.length
+    val evaluationArray = Array.fill(numIterations)(0.0)
+
+    // Initial weight is 1.0
+    var predictionErrorModel = remappedData.map {i =>
+      val pred = initialTree.predict(i.features)
+      val error = loss.computeError(i, pred)
+      (pred, error)
+    }
+    evaluationArray(0) = predictionErrorModel.values.mean()
```

- **Member** (on `val initialTree = trees(0)`): Remove? It's only used once.
- **Member** (on `// Initial weight is 1.0`): May as well use the initial weight explicitly, in case that changes for some reason in the future.
- **Member** (on `var predictionErrorModel`): `predictionErrorModel` is an odd name (model?). I'd rename it to `predictionAndError`, and possibly add an explicit type for clarity.
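Taken together, those three comments suggest something like the following for the initialization. This is a sketch of the revised fragment, reusing `remappedData`, `loss`, and `evaluationArray` from the surrounding method, not the merged code:

```scala
// Use the first tree's weight explicitly rather than assuming it is 1.0,
// inline trees(0) since it is only used once, and give the running
// (prediction, error) pairs an explicit type under a clearer name.
var predictionAndError: RDD[(Double, Double)] = remappedData.map { lp =>
  val pred = trees(0).predict(lp.features) * treeWeights(0)
  val error = loss.computeError(lp, pred)
  (pred, error)
}
evaluationArray(0) = predictionAndError.values.mean()
```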
```diff
+
+    // Avoid the model being copied across numIterations.
+    val broadcastTrees = sc.broadcast(trees)
+    val broadcastWeights = sc.broadcast(treeWeights)
+
+    (1 until numIterations).map {nTree =>
+      predictionErrorModel = (remappedData zip predictionErrorModel) map {
```

- **Member** (on `{nTree =>`): Space after `{`.
- **Member** (on the `zip`/`map` line): I would use `mapPartitions`. Before iterating over the partition elements, extract the trees and weights from the broadcast variables; I believe that reduces overhead a little. Also, try to avoid infix notation, since non-Scala people may not be used to it.
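A sketch of the loop after applying both comments, assuming the `predictionAndError` rename from above; again a fragment that relies on the surrounding method's definitions:

```scala
// foreach, since the loop is executed only for its side effects.
(1 until numIterations).foreach { nTree =>
  predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
    // Dereference the broadcast variables once per partition instead of
    // once per record.
    val currentTree = broadcastTrees.value(nTree)
    val currentWeight = broadcastWeights.value(nTree)
    iter.map { case (point, (pred, _)) =>
      // Fold the nTree-th tree into the running prediction and re-score.
      val newPred = pred + currentTree.predict(point.features) * currentWeight
      val newError = loss.computeError(point, newPred)
      (newPred, newError)
    }
  }
  evaluationArray(nTree) = predictionAndError.values.mean()
}
```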
```diff
+        case (point, (pred, error)) => {
+          val newPred = pred + (
+            broadcastTrees.value(nTree).predict(point.features) * broadcastWeights.value(nTree))
+          val newError = loss.computeError(point, newPred)
+          (newPred, newError)
+        }
+      }
+      evaluationArray(nTree) = predictionErrorModel.values.mean()
+    }
+    evaluationArray
```

- **Member** (on `val newPred = pred + (`): Style: parenthesis on the next line.
- **Member** (on the return value): You might want to explicitly unpersist the broadcast values before returning. They will get unpersisted once their values go out of scope, but it might take longer.
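The cleanup the last comment asks for could be as small as this sketch, placed just before the return:

```scala
// Release the broadcast data on the executors now, rather than waiting for
// the driver-side references to go out of scope.
broadcastTrees.unpersist()
broadcastWeights.unpersist()
evaluationArray
```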
The remainder of the hunk:

```diff
+  }
+
 }

 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
```
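For context, a hypothetical end-to-end use of the new helper. `trainingData` and `validationData` are assumed `RDD[LabeledPoint]` variables, and the strategy settings are illustrative:

```scala
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy

val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.numIterations = 30

val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

// evaluationArray(i) holds the validation loss of the ensemble built from
// the first i + 1 trees: one number per boosting iteration.
val evaluationArray = model.evaluateEachIteration(validationData, boostingStrategy.loss)
evaluationArray.zipWithIndex.foreach { case (err, i) =>
  println(s"trees 0 to $i: loss = $err")
}

// The learning curve also reveals the ensemble size that generalizes best.
val bestSize = evaluationArray.zipWithIndex.minBy(_._1)._2 + 1
```

On held-out data the curve typically falls and then flattens or turns back up, which is exactly the shape the test suite below checks for.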
**`GradientBoostedTreesSuite`** (`@@ -175,10 +175,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {`)

```diff
         new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
       val gbtValidate = new GradientBoostedTrees(boostingStrategy)
         .runWithValidation(trainRdd, validateRdd)
-      assert(gbtValidate.numTrees !== numIterations)
+      val numTrees = gbtValidate.numTrees
+      assert(numTrees !== numIterations)

       // Test that it performs better on the validation dataset.
-      val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+      val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
       val (errorWithoutValidation, errorWithValidation) = {
         if (algo == Classification) {
           val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
```

And in the same suite (`@@ -188,6 +189,17 @@`):

```diff
         }
       }
       assert(errorWithValidation <= errorWithoutValidation)
+
+      // Test that results from evaluateEachIteration comply with runWithValidation.
+      // Note that convergenceTol is set to 0.0
+      val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+      assert(evaluationArray.length === numIterations)
+      assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+      var i = 1
+      while (i < numTrees) {
+        assert(evaluationArray(i) < evaluationArray(i - 1))
+        i += 1
+      }
     }
   }
 }
```

- **Member** (on `evaluationArray(i) < evaluationArray(i - 1)`): Small issue: `<` should be `<=`.
**Follow-up discussion** (on the doc comments for `computeError` in the loss classes):

- **Member:** No need for a doc comment; it will be inherited from the overridden method (here and in the other two loss classes).
- **Author:** But the doc for the return value is different, no?
- **Member:** I think it's OK for the docs for `gradient()` and `computeError()` to be generic, as long as the doc for each loss class describes the specific loss function.
- **Author:** OK, so should I remove it?
- **Member:** Yes, please.