-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit #15414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Test build #66629 has finished for PR 15414 at commit
|
|
Test build #66635 has finished for PR 15414 at commit
|
|
Jenkins, test this please |
|
Test build #66637 has finished for PR 15414 at commit
|
|
Jenkins, retest this please |
|
Test build #66649 has finished for PR 15414 at commit
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe simplify it: dataset.schema("value").metadata
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
|
What do you think about adding a new suite class MockPredictor(override val uid: String)
extends Predictor[Vector, MockPredictor, MockPredictionModel] {
override def train(dataset: Dataset[_]): MockPredictionModel = {
require(dataset.schema("label").dataType == DoubleType)
new MockPredictionModel(uid)
}
override def copy(extra: ParamMap): MockPredictor = defaultCopy(extra)
}
class MockPredictionModel(override val uid: String)
extends PredictionModel[Vector, MockPredictionModel] {
override def predict(features: Vector): Double = 1.0
override def copy(extra: ParamMap): MockPredictionModel = defaultCopy(extra)
}Then we just have a test that calls |
|
Ok, I will create this Suite. |
6c2a8d0 to
6c61e73
Compare
|
Test build #66710 has finished for PR 15414 at commit
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move into companion object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's just put this logic directly in fit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this for? If the intent is to force getNumClasses to infer the number of classes, then you're no longer testing the not inferred case. Further, the point of this PR is to eliminate the need to do that since it is not a robust solution, IMO.
Also, I'd like to remove the dependence on TreeTests here (and genRegressionDF) and just explicitly set the attributes in the functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I will revert this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why don't we just cycle through the types here and call fit. I think it's a bit confusing the way it is now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will update this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
override def predict(features: Vector): Double = throw new NotImplementedError() We can do this for everything except train.
6c61e73 to
6ef17b7
Compare
|
Test build #66814 has finished for PR 15414 at commit
|
|
@sethah I have maken some modification according to the comments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change the copy methods to throw NotImplementedError
|
Thanks, I'll take a more detailed look in the next couple of days. Let's also wait and see if we can get @yanboliang or @jkbradley to give an opinion. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't need DefaultReadWriteTest
|
Test build #66872 has finished for PR 15414 at commit
|
|
Test build #66880 has finished for PR 15414 at commit
|
|
@jkbradley @yanboliang Could you please have a review of this? This PR unify usage of labelCol casting and fixs a bug described in [https://issues.apache.org/jira/browse/SPARK-17797] |
|
@jkbradley @yanboliang Just re-pinging for your opinions. |
|
Can you please document in Predictor that it accepts all NumericType labels? Other than that, this LGTM. Thanks! |
|
LGTM as well after adding @jkbradley's suggestion. |
7cb4510 to
810c973
Compare
|
@jkbradley @sethah I add a comment, thanks for reviews. |
|
Test build #67861 has finished for PR 15414 at commit
|
|
LGTM |
## What changes were proposed in this pull request? 1, move cast to `Predictor` 2, and then, remove unnecessary cast ## How was this patch tested? existing tests Author: Zheng RuiFeng <[email protected]> Closes apache#15414 from zhengruifeng/move_cast.
What changes were proposed in this pull request?
1, move cast to
Predictor2, and then, remove unnecessary cast
How was this patch tested?
existing tests