[SPARK-6025] [MLlib] Add helper method evaluateEachIteration to extract learning curve #4906
Changes from 4 commits
@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}
```scala
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
```
@@ -108,6 +110,58 @@ class GradientBoostedTreesModel(
```scala
  }

  override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion

  /**
   * Method to compute error or loss for every iteration of gradient boosting.
   * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
   * @param loss evaluation metric.
   * @return an array with index i having the losses or errors for the ensemble
   *         containing trees 1 to i + 1
   */
  def evaluateEachIteration(
      data: RDD[LabeledPoint],
      loss: Loss): Array[Double] = {

    val sc = data.sparkContext
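    // Classification labels arrive as {0, 1}; remap them to {-1, +1}, the label
    // space MLlib's tree-boosting losses (e.g. log loss) are defined on.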
    val remappedData = algo match {
      case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
      case _ => data
    }

    val numIterations = trees.length
    val evaluationArray = Array.fill(numIterations)(0.0)

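    // Error of the one-tree ensemble: predictions from the first tree alone.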
    var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
      val pred = treeWeights(0) * trees(0).predict(i.features)
      val error = loss.computeError(pred, i.label)
      (pred, error)
    }
    evaluationArray(0) = predictionAndError.values.mean()

    // Avoid the model being copied across numIterations.
    val broadcastTrees = sc.broadcast(trees)
    val broadcastWeights = sc.broadcast(treeWeights)

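    // Each subsequent pass adds one tree: zip the data with the previous
    // iteration's (prediction, error) pairs and update both incrementally.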
    (1 until numIterations).map { nTree =>
      predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
        val currentTree = broadcastTrees.value(nTree)
        val currentTreeWeight = broadcastWeights.value(nTree)
        iter.map {
          case (point, (pred, error)) => {
            val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
```
Member: I just realized this is correct for regression but not for classification. For classification, it should threshold as in https://github.com/apache/spark/blob/e3f315ac358dfe4f5b9705c3eac76e8b1e24f82a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala#L194. It is also a problem that the test suite did not find this error. Could you please first fix the test suite so that it fails because of this error, and then fix it here? Thanks! Sorry I didn't realize it before.
Contributor (author): I think this is more of a design problem. Do we want …
Contributor (author): There is also the fact that runWithValidation breaks according to the regression loss and not the classification loss. This might lead to different results when runWithValidation and evaluateEachIteration are used. I suggest we keep this as it is and maybe add a comment?
Member: You're right; I was getting confused. It's correct to use the raw prediction for classification, as you're doing.
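For context, here is a minimal standalone sketch of the distinction this thread is about: scoring the raw (margin) prediction directly versus thresholding it into a class label first. The values and names below are made up for illustration; only the log-loss formula mirrors MLlib's LogLoss.computeError.

```scala
object RawVsThresholded {
  def main(args: Array[String]): Unit = {
    // Hypothetical values: a boosted ensemble's raw score for one point,
    // and that point's label remapped into {-1, +1}.
    val rawPrediction = 0.3
    val label = -1.0

    // Raw-margin evaluation, as evaluateEachIteration does: the loss consumes
    // the unthresholded prediction (MLlib log loss: 2 * log(1 + e^(-2 * label * pred))).
    val logLoss = 2.0 * math.log1p(math.exp(-2.0 * label * rawPrediction))

    // Thresholded evaluation, as the model's predict() does for Classification:
    // the sign of the raw score picks the class (shown here in {-1, +1} space).
    val predictedClass = if (rawPrediction > 0.0) 1.0 else -1.0
    val zeroOneError = if (predictedClass == label) 0.0 else 1.0

    println(s"log loss on raw margin: $logLoss, 0/1 error after threshold: $zeroOneError")
  }
}
```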
```scala
            val newError = loss.computeError(newPred, point.label)
            (newPred, newError)
          }
        }
      }
      evaluationArray(nTree) = predictionAndError.values.mean()
    }

    broadcastTrees.unpersist()
    broadcastWeights.unpersist()
    evaluationArray
```
Member: You might want to explicitly unpersist the broadcast values before returning. They will get unpersisted once their values go out of scope, but it might take longer.
```scala
  }
}

object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
```
Review comment on the doc above: Minor: use 0-based indexing in the doc ("containing trees 0 to i"), or just say "containing the first i + 1 trees".
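To close, a sketch of how the new helper would be called once merged: train a small regression GBT and extract the per-iteration training error. The input path and parameter values are illustrative assumptions; evaluateEachIteration is the API added by this PR, and GradientBoostedTrees, BoostingStrategy, and AbsoluteError are existing MLlib API.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.loss.AbsoluteError
import org.apache.spark.mllib.util.MLUtils

object LearningCurveExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("LearningCurve"))

    // Illustrative input path, in LIBSVM format.
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 30

    val model = GradientBoostedTrees.train(data, boostingStrategy)

    // The helper under review: losses(i) is the training error of the
    // ensemble made of the first i + 1 trees.
    val losses = model.evaluateEachIteration(data, AbsoluteError)
    losses.zipWithIndex.foreach { case (l, i) => println(s"iteration ${i + 1}: loss $l") }

    sc.stop()
  }
}
```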