[SPARK-14709][ML] spark.ml API for linear SVM #15211
Conversation
@yanboliang this is a quick prototype for the ML SVM. I'll make another pass tomorrow to refine the code and add more unit tests. It mimics much of the behavior of LR, though I'm not sure whether we need to match LR's level of complexity.
Test build #65816 has finished for PR 15211 at commit
Test build #65898 has finished for PR 15211 at commit
     */
    @Since("2.1.0")
    @Experimental
    class SVM @Since("2.1.0") (
What about SVMClassifier? We can also train regression models with SVM.
I changed it to LinearSVC, so we can add other SVM classifiers in the future.
    @Experimental
    class SVM @Since("2.1.0") (
        @Since("2.1.0") override val uid: String)
      extends Predictor[Vector, SVM, SVMModel]
Should this be under the Classifier or ProbabilisticClassifier framework?
Thanks, I changed it to Classifier. AFAIK, the SVM raw prediction cannot be interpreted as a probability.
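For context on what this choice means at the API surface, here is a minimal usage sketch (it assumes the class as it eventually shipped, LinearSVC, and uses an illustrative libsvm-format file path): a Classifier-based model produces rawPrediction and prediction columns but no probability column, consistent with the point that the hinge margin is not a probability.

    import org.apache.spark.ml.classification.LinearSVC
    import org.apache.spark.sql.SparkSession

    object LinearSVCSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]").appName("linear-svc-sketch").getOrCreate()

        // Illustrative data path; any binary-labeled libsvm-format file works.
        val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

        val model = new LinearSVC().setMaxIter(10).setRegParam(0.1).fit(data)

        // Classifier => rawPrediction (the margin) and prediction, but no probability column.
        model.transform(data).select("label", "rawPrediction", "prediction").show(5)

        spark.stop()
      }
    }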
    override protected def train(dataset: Dataset[_]): SVMModel = {
      val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
      val instances: RDD[Instance] =
        dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
labelCol no longer needs to be cast to DoubleType here, because it is already cast in Predictor.fit(); see #15414.
Yes, thanks.
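A sketch of what the simplified extraction inside train() could look like once the explicit cast is dropped (this assumes, per #15414, that Predictor.fit() already casts labelCol to DoubleType; names and imports follow the hunk above, not necessarily the final code):

    // No col(...).cast(DoubleType) needed for the label anymore:
    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
    val instances: RDD[Instance] =
      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
        case Row(label: Double, weight: Double, features: Vector) =>
          Instance(label, weight, features)
      }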
      with SVMParams with DefaultParamsWritable {

      @Since("2.1.0")
      def this() = this(Identifiable.randomUID("svm"))
If we rename this to SVMClassifier, the uid should be "svc".
    }

    val instr = Instrumentation.create(this, instances)
    instr.logParams(regParam, standardization, threshold, maxIter, tol, fitIntercept)
To keep in line with other algorithms, labelCol, weightCol, and featuresCol should be logged here as well.
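That is, something along these lines (a sketch of the suggestion, not necessarily the final param list):

    instr.logParams(regParam, standardization, threshold, maxIter, tol, fitIntercept,
      labelCol, weightCol, featuresCol)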
        n
      case None => histogram.length
    }
require(numClasses == 2, "...") ?
    val localCoefficientsArray = coefficientsArray
    val localGradientSumArray = gradientSumArray

    numClasses match {
What about moving the numClasses check outside of add()?
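Combined with the require(numClasses == 2, "...") suggestion above, a sketch of the hoisting (the wording and placement are illustrative, not the final implementation):

    // In train(), right after numClasses is derived from the label histogram:
    require(numClasses == 2,
      s"LinearSVC only supports binary classification, but $numClasses classes were found.")

    // With the check done once up front, the aggregator's add() can handle the
    // binary case directly instead of matching on numClasses for every instance.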
Thanks for the suggestions @yanboliang @zhengruifeng. I'm thinking of renaming the class to LinearSVMClassifier. As in other linear SVM implementations, this will only be a binary classifier; multiclass classification will be supported via one-vs-rest.
Test build #70155 has finished for PR 15211 at commit
Sent an update to address some comments.
Test build #70232 has finished for PR 15211 at commit
Removed WIP. This is ready for review. Thanks.
Test build #71105 has finished for PR 15211 at commit
Sent an update to include an R unit test. However, I ran into a problem: there is a constant scaling difference between LinearSVC and R's e1071 (which is essentially LIBSVM). It may be caused by some parameter setting. Posting it anyway in case there are suggestions. Sorry @zhengruifeng, I'll address your comment in the next update.
jkbradley left a comment
For comparing with R, I'm wondering if the main issue is that it's hard to calculate the appropriate C given a regParam setting. Would it be easier to use this R package instead? https://cran.r-project.org/web/packages/svmpath/
Also, the test with sample weights takes 40 seconds. Does it still pass if you increase the 'tol' Param to make the test faster?
    Use the following R code to load the data and train the model using the e1071 package.
    library(e1071)
    data <- read.csv("/home/yuhao/workspace/github/hhbyyh/Test/SVM/svm/part-00000", header=FALSE)
How about basing the data location at target/tmp/LinearSVC/binaryDataset to match the data export?
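A sketch of an export step that would write the generated data to the suggested location (the dataset name, delimiter, and column layout are illustrative assumptions, chosen to match what the R snippet reads):

    // Dump the generated binary dataset under the build directory so the R
    // snippet can read it from a stable, relative location.
    binaryDataset.rdd
      .map(row => row.getDouble(0) + "," +
        row.getAs[org.apache.spark.ml.linalg.Vector](1).toArray.mkString(","))
      .repartition(1)
      .saveAsTextFile("target/tmp/LinearSVC/binaryDataset")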
     */
    val coefficientsR = Vectors.dense(-7.310475, -14.89742, -22.21019, -29.83495)
    val interceptR = -7.440296
    assert(model1.intercept / interceptR ~== -0.9 relTol 2E-2)
This is a strange way to write the comparison. Was this a temporary thing to make the tests pass?
Thanks @jkbradley. I reviewed the gradient and loss function. The corresponding regularization lambda should satisfy lambda * C * N (data size) = 2. I corrected the unit test against R's e1071 and added another unit test against scikit-learn.
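As a quick consistency check of that relationship (the dataset size N = 10,000 is inferred from these numbers rather than stated anywhere in the thread):

    \lambda = \frac{2}{C \cdot N} = \frac{2}{10 \times 10000} = 2 \times 10^{-5}

which matches cost = 10 in the e1071 call and setRegParam(0.00002) in the Spark test below.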
Test build #71618 has finished for PR 15211 at commit
Test build #71619 has finished for PR 15211 at commit
jkbradley left a comment
@hhbyyh This looks great. Combining the tests to reduce test time is the only remaining issue, I believe.
      assert(model1.coefficients ~== coefficientsR relTol 1E-2)
    }

    test("linearSVC comparison with scikit-learn") {
Let's combine these to avoid retraining since this is the same as the R test. (And training takes 35 sec)
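One way the combined test could be structured so the model is fit only once and checked against both references (a sketch; the regParam value comes from this thread, while the maxIter/tol settings and the coefficientsR/interceptR/~== helpers are assumed to be those defined in the test file):

    test("linearSVC comparison with R e1071 and scikit-learn") {
      val trainer = new LinearSVC()
        .setRegParam(0.00002) // equivalent to cost = 10 in the e1071/LIBSVM run, via regParam = 2 / (C * N)
        .setMaxIter(200)
        .setTol(1e-4)
      val model = trainer.fit(binaryDataset)

      // Both reference implementations converge to (nearly) the same solution,
      // so one fitted model can be compared against both sets of expected values.
      assert(model.coefficients ~== coefficientsR relTol 1E-3)
      assert(model.intercept ~== interceptR relTol 1E-3)
    }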
    features <- as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
    svm_model <- svm(features, label, type='C', kernel='linear', cost=10, scale=F, tolerance=1e-4)
    summary(svm_model)
    w <- -t(svm_model$coefs) %*% svm_model$SV
Remove "-" to make values match lines coefficientsR, interceptR below
    > w
            data.V2   data.V3   data.V4   data.V5
    [1,] -7.310475 -14.89742 -22.21019 -29.83495
    > -svm_model$rho
same here
"-" here may be necessary as b = -model$rho #Offset
    test("linearSVC comparison with scikit-learn") {
      val trainer1 = new LinearSVC()
        .setRegParam(0.00002)
Add a comment that this matches C above.
Test build #71620 has finished for PR 15211 at commit
I see, I'll work on combining the tests now. Also, I'm wondering if we should consider using
Test build #71623 has finished for PR 15211 at commit
Test build #71625 has finished for PR 15211 at commit
jkbradley left a comment
I'd like to keep regParam. I think it's about as common in literature and practice as specifying the constraint C.
    > w
            data.V2   data.V3   data.V4   data.V5
    [1,] -7.310338 -14.89741 -22.21005 -29.83508
w and the intercept are still negative, which isn't what we wanted, right?
I misunderstood your last comment. Changing them to positive now.
    [1] -7.440177
     */
    val coefficientsR = Vectors.dense(7.310475, 14.89742, 22.21019, 29.83495)
Why are these values changed?
The updated weights are the correct ones, and they are stable. I forgot how I generated the original weights... maybe I used a different random seed for data generation.
Test build #71742 has finished for PR 15211 at commit
LGTM. One more step towards feature parity for the DataFrame-based API!
I'll create follow-up JIRAs (linked from this PR's JIRA). @hhbyyh Can I assign one or more to you?
Thanks @jkbradley for driving the review process.
What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-14709
Provide an API for the SVM algorithm on DataFrames. As discussed in the JIRA, the initial implementation uses OWL-QN with the hinge loss function.
The API should mimic existing spark.ml.classification APIs.
Currently only binary classification is supported. Multiclass support can be added in this or a following release.
How was this patch tested?
New unit tests and a simple manual test.
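Since only binary classification is supported for now, here is a hedged sketch of how multiclass could be handled in the meantime with the existing OneVsRest meta-estimator, as mentioned earlier in the thread (assumes a Spark build that includes this PR's LinearSVC; the data path is illustrative):

    import org.apache.spark.ml.classification.{LinearSVC, OneVsRest}
    import org.apache.spark.sql.SparkSession

    object LinearSVCOneVsRestSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]").appName("linear-svc-ovr").getOrCreate()

        // Illustrative multiclass dataset in libsvm format.
        val data = spark.read.format("libsvm")
          .load("data/mllib/sample_multiclass_classification_data.txt")

        // LinearSVC itself is binary-only; OneVsRest trains one binary model per class.
        val svc = new LinearSVC().setMaxIter(20).setRegParam(0.1)
        val ovrModel = new OneVsRest().setClassifier(svc).fit(data)

        ovrModel.transform(data).select("label", "prediction").show(5)

        spark.stop()
      }
    }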