Contributor

@hhbyyh hhbyyh commented Sep 23, 2016

What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-14709

Provide API for SVM algorithm for DataFrames. As discussed in jira, the initial implementation uses OWL-QN with Hinge loss function.
The API should mimic existing spark.ml.classification APIs.
Currently only binary classification is supported. Multinomial support can be added in this or a following release.
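For readers unfamiliar with the loss mentioned above, here is a minimal sketch of the hinge loss and its subgradient for a single labeled point. It is written in plain Scala with no Spark dependency; the object and method names are hypothetical and are not taken from the patch.

```scala
// Hypothetical helper for illustration only; not part of the patch.
object HingeLossSketch {
  /**
   * Returns (loss, subgradient with respect to the weights) for one point.
   * Labels are assumed to be 0.0 or 1.0 and are mapped to -1 / +1.
   */
  def hinge(
      weights: Array[Double],
      intercept: Double,
      features: Array[Double],
      label: Double): (Double, Array[Double]) = {
    require(weights.length == features.length, "dimension mismatch")
    val y = 2.0 * label - 1.0 // map {0, 1} to {-1, +1}
    var margin = intercept
    for (j <- weights.indices) margin += weights(j) * features(j)
    val loss = math.max(0.0, 1.0 - y * margin)
    // The subgradient is -y * x when the point violates the margin, else zero.
    val grad =
      if (loss > 0.0) features.map(x => -y * x)
      else Array.fill(features.length)(0.0)
    (loss, grad)
  }
}
```

A point that is classified correctly with a margin of at least 1 contributes zero loss and zero gradient, which is why only margin violators influence the fitted weights.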

How was this patch tested?

new unit tests and simple manual test

Contributor Author

hhbyyh commented Sep 23, 2016

@yanboliang this is a quick prototype for the ML SVM. I'll make another pass tomorrow to refine the code and add more unit tests. It mimics much of LR's behavior, yet I'm not sure we need to reach the same complexity as LR.
Reviewers, I'd appreciate it if you could take a quick look in case anything is unexpected. We can work on the details after I remove the WIP tag.

SparkQA commented Sep 23, 2016

Test build #65816 has finished for PR 15211 at commit f8ddc3b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class SVM @Since("2.1.0") (
    • class SVMModelWriter(instance: SVMModel)

SparkQA commented Sep 26, 2016

Test build #65898 has finished for PR 15211 at commit 73b8011.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

*/
@Since("2.1.0")
@Experimental
class SVM @Since("2.1.0") (
Contributor

What about SVMClassifier? We can also train regression models with SVM.

Contributor Author

I changed it to LinearSVC, so we can add other SVM classifiers in the future.

@Experimental
class SVM @Since("2.1.0") (
@Since("2.1.0") override val uid: String)
extends Predictor[Vector, SVM, SVMModel]
Contributor

Should this sit under the Classifier or ProbabilisticClassifier framework instead?

Contributor Author

Thanks, I changed it to Classifier. AFAIK, the raw SVM output cannot be used as a probability.
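To illustrate the point about raw output, here is a small sketch of how a linear SVM turns a margin into a label. The margin is an unbounded signed score, so it is compared against a threshold rather than read as a probability. The names and the default threshold of 0.0 are assumptions for illustration, not code from the patch.

```scala
// Hypothetical sketch; names and defaults are illustrative assumptions.
object MarginPredictSketch {
  /** Signed distance-like score w . x + b; it is not bounded to [0, 1]. */
  def margin(weights: Array[Double], intercept: Double, features: Array[Double]): Double = {
    require(weights.length == features.length, "dimension mismatch")
    var m = intercept
    for (j <- weights.indices) m += weights(j) * features(j)
    m
  }

  /** Predicts 1.0 when the margin exceeds the threshold, otherwise 0.0. */
  def predict(m: Double, threshold: Double = 0.0): Double =
    if (m > threshold) 1.0 else 0.0
}
```

Because the score can be any real number, a classifier of this kind exposes a raw prediction and a label but no calibrated probability column.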

override protected def train(dataset: Dataset[_]): SVMModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
Contributor

labelCol no longer needs to be cast to DoubleType, because it is already cast in Predictor.fit();
see #15414

Contributor Author

Yes, thanks.

with SVMParams with DefaultParamsWritable {

@Since("2.1.0")
def this() = this(Identifiable.randomUID("svm"))
Contributor

If we rename this to SVMClassifier, the uid should be svc.

}

val instr = Instrumentation.create(this, instances)
instr.logParams(regParam, standardization, threshold, maxIter, tol, fitIntercept)
Contributor

To stay consistent with other algorithms, labelCol, weightCol, and featuresCol should be logged here as well.

n
case None => histogram.length
}

Contributor

require(numClasses == 2, "...") ?

val localCoefficientsArray = coefficientsArray
val localGradientSumArray = gradientSumArray

numClasses match {
Contributor

What about moving the numClasses check outside of add()?
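One way to read this suggestion is to validate numClasses once when the aggregator is constructed, so that add() stays a tight per-instance loop. The class below is a hypothetical plain-Scala sketch of that idea and does not reproduce the aggregator in the patch.

```scala
// Hypothetical aggregator sketch; not the class from the patch.
final class HingeAggregatorSketch(numClasses: Int, weights: Array[Double]) {
  // Checked once at construction time instead of on every add() call.
  require(numClasses == 2,
    s"Only binary classification is supported, but got $numClasses classes.")

  private var lossSum = 0.0
  private var count = 0L

  /** Accumulates the hinge loss of one point; label is assumed to be 0.0 or 1.0. */
  def add(features: Array[Double], label: Double): this.type = {
    require(features.length == weights.length, "dimension mismatch")
    val y = 2.0 * label - 1.0
    var margin = 0.0
    var j = 0
    while (j < weights.length) {
      margin += weights(j) * features(j)
      j += 1
    }
    lossSum += math.max(0.0, 1.0 - y * margin)
    count += 1
    this
  }

  def meanLoss: Double = if (count == 0) 0.0 else lossSum / count
}
```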

Contributor Author

hhbyyh commented Dec 10, 2016

Thanks for the suggestions @yanboliang @zhengruifeng. I'm thinking of renaming the class to LinearSVMClassifier. As in other linear SVM implementations, this will only be a binary classifier, and multi-class classification will be supported via one-vs-rest.
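As a rough illustration of the one-vs-rest idea mentioned here: one binary scorer is trained per class, and the predicted class is the one whose scorer returns the largest margin. The sketch below is plain Scala with hypothetical names and only shows the prediction step, under the assumption that each scorer maps a feature array to a margin.

```scala
// Hypothetical sketch of one-vs-rest prediction; names are illustrative.
object OneVsRestSketch {
  /** A binary scorer maps features to a signed margin for "its" class. */
  type BinaryScorer = Array[Double] => Double

  /** Returns the index of the class whose scorer gives the highest margin. */
  def predict(scorers: Seq[BinaryScorer], features: Array[Double]): Int = {
    require(scorers.nonEmpty, "at least one scorer is required")
    scorers.map(s => s(features)).zipWithIndex.maxBy(_._1)._2
  }
}
```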

SparkQA commented Dec 15, 2016

Test build #70155 has finished for PR 15211 at commit 4902517.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Contributor Author

hhbyyh commented Dec 15, 2016

Sent an update to address some comments.

SparkQA commented Dec 16, 2016

Test build #70232 has finished for PR 15211 at commit c8a7553.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hhbyyh hhbyyh changed the title from "[SPARK-14709][ML] [WIP] spark.ml API for linear SVM" to "[SPARK-14709][ML] spark.ml API for linear SVM" Dec 16, 2016
Contributor Author

hhbyyh commented Dec 16, 2016

Remove WIP. This is ready for review. Thanks.

SparkQA commented Jan 10, 2017

Test build #71105 has finished for PR 15211 at commit 05fbd02.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Contributor Author

hhbyyh commented Jan 10, 2017

Sent an update to include an R unit test. However, I ran into a problem: there is a constant scaling difference between LinearSVC and R's e1071 package (which is essentially LibSVM). It may be caused by a parameter setting. I'm posting it anyway to see if anyone has suggestions.

Sorry @zhengruifeng, I'll address your comment in the next update.

Member

@jkbradley jkbradley left a comment

For comparing with R, I'm wondering if the main issue is that it's hard to calculate the appropriate C given a regParam setting. Would it be easier to use this R package instead? https://cran.r-project.org/web/packages/svmpath/

Also, the test with sample weights takes 40 seconds. Does it still pass if you increase the 'tol' Param to make the test faster?

Use the following R code to load the data and train the model using the e1071 package.
library(e1071)
data <- read.csv("/home/yuhao/workspace/github/hhbyyh/Test/SVM/svm/part-00000", header=FALSE)
Member

How about basing the data location at target/tmp/LinearSVC/binaryDataset to match the data export?

*/
val coefficientsR = Vectors.dense(-7.310475, -14.89742, -22.21019, -29.83495)
val interceptR = -7.440296
assert(model1.intercept / interceptR ~== -0.9 relTol 2E-2)
Member

This is a strange way to write the comparison. Was this a temporary thing to make the tests pass?
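For context, a more direct way to write such a check is to compare the fitted value against the reference value itself under a relative tolerance, rather than comparing their ratio to a constant. The helper below is a hypothetical plain-Scala sketch of a relative-tolerance comparison; it is not the operator used in Spark's test utilities and may differ from it in detail.

```scala
// Hypothetical helper; not Spark's test utility.
object RelTolSketch {
  /** True when |actual - expected| is within relTol of the larger magnitude. */
  def approxEqual(actual: Double, expected: Double, relTol: Double): Boolean = {
    require(relTol >= 0.0, "relTol must be non-negative")
    val diff = math.abs(actual - expected)
    diff <= relTol * math.max(math.abs(actual), math.abs(expected))
  }
}
```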

Contributor Author

hhbyyh commented Jan 18, 2017

Thanks @jkbradley.

I reviewed the gradient and loss function. The corresponding regularization parameter lambda should satisfy lambda * c * N = 2, where N is the data size.
c is set to 10 in the R and sklearn runs, so we should set reg = 0.00002 (2 / 10000 / 10).
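The arithmetic in this comment can be captured in a small helper. This is a sketch that simply encodes the relation stated above (lambda * c * N = 2) as given in the thread; the function name is hypothetical, and the relation itself is taken from the comment rather than derived independently here.

```scala
// Hypothetical helper encoding the relation stated in the comment above.
object RegParamSketch {
  /** Solves lambda * cost * numInstances = 2 for lambda. */
  def regParamFor(cost: Double, numInstances: Long): Double = {
    require(cost > 0.0, "cost must be positive")
    require(numInstances > 0L, "numInstances must be positive")
    2.0 / (cost * numInstances)
  }
}
```

With cost = 10 and 10000 instances, this gives the 0.00002 used in the tests.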

I corrected the unit test against R's e1071 package and added another unit test against sklearn.

SparkQA commented Jan 18, 2017

Test build #71618 has finished for PR 15211 at commit 36f585c.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

SparkQA commented Jan 18, 2017

Test build #71619 has finished for PR 15211 at commit 72719fc.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Member

@jkbradley jkbradley left a comment

@hhbyyh This looks great. Combining the tests to reduce test time is the only remaining issue, I believe.

assert(model1.coefficients ~== coefficientsR relTol 1E-2)
}

test("linearSVC comparison with scikit-learn") {
Member

Let's combine these to avoid retraining since this is the same as the R test. (And training takes 35 sec)

features <- as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
svm_model <- svm(features, label, type='C', kernel='linear', cost=10, scale=F, tolerance=1e-4)
summary(svm_model)
w <- -t(svm_model$coefs) %*% svm_model$SV
Member

Remove the "-" so the values match coefficientsR and interceptR below.

> w
data.V2 data.V3 data.V4 data.V5
[1,] -7.310475 -14.89742 -22.21019 -29.83495
> -svm_model$rho
Member

same here

Contributor Author

The "-" here may be necessary, since the offset is b = -model$rho.
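To make the sign discussion concrete, here is a sketch of how primal weights and the intercept are recovered from a LibSVM-style dual solution: the weight vector is the coefficient-weighted sum of the support vectors (the R expression t(coefs) %*% SV), and the offset is the negated rho. The names are hypothetical; the overall sign of the result can depend on which class the library treats as positive, which is an assumption worth checking against the library in use.

```scala
// Hypothetical sketch; mirrors the R expressions quoted in the thread.
object PrimalFromDualSketch {
  /** w = sum_i coefs(i) * supportVectors(i), one coefficient per support vector. */
  def weights(coefs: Array[Double], supportVectors: Array[Array[Double]]): Array[Double] = {
    require(coefs.length == supportVectors.length, "one coefficient per support vector")
    val dim = supportVectors.headOption.map(_.length).getOrElse(0)
    val w = Array.fill(dim)(0.0)
    for (i <- coefs.indices; j <- 0 until dim) {
      w(j) += coefs(i) * supportVectors(i)(j)
    }
    w
  }

  /** The offset b is the negated rho, as in b = -model$rho. */
  def intercept(rho: Double): Double = -rho
}
```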


test("linearSVC comparison with scikit-learn") {
val trainer1 = new LinearSVC()
.setRegParam(0.00002)
Member

Add a comment that this matches C above.

SparkQA commented Jan 18, 2017

Test build #71620 has finished for PR 15211 at commit 2e99a0f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Contributor Author

hhbyyh commented Jan 18, 2017

I see, I'll work on combining the tests now. I'm also wondering whether we should use c (cost) instead of regParam in LinearSVC to be friendlier to SVM users. That change might be confusing for Spark users, though. I'm neutral on this.

SparkQA commented Jan 19, 2017

Test build #71623 has finished for PR 15211 at commit d62c107.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

SparkQA commented Jan 19, 2017

Test build #71625 has finished for PR 15211 at commit 9b71a3a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Member

@jkbradley jkbradley left a comment

I'd like to keep regParam. I think it's about as common in literature and practice as specifying the constraint C.

> w
data.V2 data.V3 data.V4 data.V5
[1,] -7.310338 -14.89741 -22.21005 -29.83508
Member

w and the intercept are still negative, which isn't what we wanted, right?

Contributor Author

@hhbyyh hhbyyh Jan 20, 2017

I misunderstood your last comment. Changing them to positive now.

[1] -7.440177
*/
val coefficientsR = Vectors.dense(7.310475, 14.89742, 22.21019, 29.83495)
Member

Why are these values changed?

Contributor Author

The updated weights are the correct ones and they are stable. I no longer remember how the original weights were generated; possibly with a different random seed for data generation.

SparkQA commented Jan 20, 2017

Test build #71742 has finished for PR 15211 at commit bbcb7cb.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

LGTM
Thanks @hhbyyh and also @yanboliang and @zhengruifeng for helping with review!
Merging with master

One more step towards feature parity for the DataFrame-based API!

@jkbradley
Member

I'll create follow-up JIRAs (linked from this PR's JIRA). @hhbyyh Can I assign one or more to you?

@asfgit asfgit closed this in 4a11d02 Jan 23, 2017
Contributor Author

hhbyyh commented Jan 23, 2017

Thanks @jkbradley for driving the review process.
Also thanks @yanboliang and @zhengruifeng for the helpful comments.
Sure, I'd like to keep working on the follow-up tasks.

uzadude pushed a commit to uzadude/spark that referenced this pull request Jan 27, 2017
## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-14709

Provide API for SVM algorithm for DataFrames. As discussed in jira, the initial implementation uses OWL-QN with Hinge loss function.
The API should mimic existing spark.ml.classification APIs.
Currently only Binary Classification is supported. Multinomial support can be added in this or following release.
## How was this patch tested?

new unit tests and simple manual test

Author: Yuhao <[email protected]>
Author: Yuhao Yang <[email protected]>

Closes apache#15211 from hhbyyh/mlsvm.
cmonkey pushed a commit to cmonkey/spark that referenced this pull request Feb 15, 2017