Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.apache.spark.ml.stat._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql._
import org.apache.spark.storage.StorageLevel

/** Params for linear SVM Classifier. */
Expand Down Expand Up @@ -267,7 +267,26 @@ class LinearSVC @Since("2.2.0") (
if (featuresStd(i) != 0.0) rawCoefficients(i) / featuresStd(i) else 0.0
}
val intercept = if ($(fitIntercept)) rawCoefficients.last else 0.0
copyValues(new LinearSVCModel(uid, Vectors.dense(coefficientArray), intercept))
createModel(dataset, Vectors.dense(coefficientArray), intercept, objectiveHistory)
}

private def createModel(
dataset: Dataset[_],
coefficients: Vector,
intercept: Double,
objectiveHistory: Array[Double]): LinearSVCModel = {
val model = copyValues(new LinearSVCModel(uid, coefficients, intercept))
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)

val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
val summary = new LinearSVCTrainingSummaryImpl(
summaryModel.transform(dataset),
rawPredictionColName,
predictionColName,
$(labelCol),
weightColName,
objectiveHistory)
model.setSummary(Some(summary))
}

private def trainOnRows(
Expand Down Expand Up @@ -352,7 +371,7 @@ class LinearSVCModel private[classification] (
@Since("2.2.0") val coefficients: Vector,
@Since("2.2.0") val intercept: Double)
extends ClassificationModel[Vector, LinearSVCModel]
with LinearSVCParams with MLWritable {
with LinearSVCParams with MLWritable with HasTrainingSummary[LinearSVCTrainingSummary] {

@Since("2.2.0")
override val numClasses: Int = 2
Expand All @@ -368,6 +387,48 @@ class LinearSVCModel private[classification] (
BLAS.dot(features, coefficients) + intercept
}

/**
* Gets summary of model on training set. An exception is thrown
* if `hasSummary` is false.
*/
@Since("3.1.0")
override def summary: LinearSVCTrainingSummary = super.summary

/**
* If the rawPrediction and prediction columns are set, this method returns the current model,
* otherwise it generates new columns for them and sets them as columns on a new copy of
* the current model
*/
private[classification] def findSummaryModel(): (LinearSVCModel, String, String) = {
val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) {
copy(ParamMap.empty)
.setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString)
.setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else if ($(rawPredictionCol).isEmpty) {
copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" +
java.util.UUID.randomUUID.toString)
} else if ($(predictionCol).isEmpty) {
copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
} else {
this
}
(model, model.getRawPredictionCol, model.getPredictionCol)
}

/**
* Evaluates the model on a test dataset.
*
* @param dataset Test dataset to evaluate model on.
*/
@Since("3.1.0")
def evaluate(dataset: Dataset[_]): LinearSVCSummary = {
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
// Handle possible missing or invalid rawPrediction or prediction columns
val (summaryModel, rawPrediction, predictionColName) = findSummaryModel()
new LinearSVCSummaryImpl(summaryModel.transform(dataset),
rawPrediction, predictionColName, $(labelCol), weightColName)
}

override def predict(features: Vector): Double = {
if (margin(features) > $(threshold)) 1.0 else 0.0
}
Expand Down Expand Up @@ -439,3 +500,53 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
}
}
}

/**
* Abstraction for LinearSVC results for a given model.
*/
sealed trait LinearSVCSummary extends BinaryClassificationSummary

/**
* Abstraction for LinearSVC training results.
*/
sealed trait LinearSVCTrainingSummary extends LinearSVCSummary with TrainingSummary

/**
* LinearSVC results for a given model.
*
* @param predictions dataframe output by the model's `transform` method.
* @param scoreCol field in "predictions" which gives the rawPrediction of each instance.
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param weightCol field in "predictions" which gives the weight of each instance.
*/
private class LinearSVCSummaryImpl(
@transient override val predictions: DataFrame,
override val scoreCol: String,
override val predictionCol: String,
override val labelCol: String,
override val weightCol: String)
extends LinearSVCSummary

/**
* LinearSVC training results.
*
* @param predictions dataframe output by the model's `transform` method.
* @param scoreCol field in "predictions" which gives the rawPrediction of each instance.
* @param predictionCol field in "predictions" which gives the prediction for a data instance as a
* double.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param weightCol field in "predictions" which gives the weight of each instance.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
private class LinearSVCTrainingSummaryImpl(
predictions: DataFrame,
scoreCol: String,
predictionCol: String,
labelCol: String,
weightCol: String,
override val objectiveHistory: Array[Double])
extends LinearSVCSummaryImpl(
predictions, scoreCol, predictionCol, labelCol, weightCol)
with LinearSVCTrainingSummary
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.functions._


class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
Expand Down Expand Up @@ -284,6 +284,57 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
assert(model1.coefficients ~== coefficientsSK relTol 4E-3)
}

test("summary and training summary") {
val lsvc = new LinearSVC()
val model = lsvc.setMaxIter(5).fit(smallBinaryDataset)

val summary = model.evaluate(smallBinaryDataset)

assert(model.summary.accuracy === summary.accuracy)
assert(model.summary.weightedPrecision === summary.weightedPrecision)
assert(model.summary.weightedRecall === summary.weightedRecall)
assert(model.summary.pr.collect() === summary.pr.collect())
assert(model.summary.roc.collect() === summary.roc.collect())
assert(model.summary.areaUnderROC === summary.areaUnderROC)

// verify instance weight works
val lsvc2 = new LinearSVC()
.setMaxIter(5)
.setWeightCol("weight")

val smallBinaryDatasetWithWeight =
smallBinaryDataset.select(col("label"), col("features"), lit(2.5).as("weight"))

val summary2 = model.evaluate(smallBinaryDatasetWithWeight)

val model2 = lsvc2.fit(smallBinaryDatasetWithWeight)
assert(model2.summary.accuracy === summary2.accuracy)
assert(model2.summary.weightedPrecision ~== summary2.weightedPrecision relTol 1e-6)
assert(model2.summary.weightedRecall === summary2.weightedRecall)
assert(model2.summary.pr.collect() === summary2.pr.collect())
assert(model2.summary.roc.collect() === summary2.roc.collect())
assert(model2.summary.areaUnderROC === summary2.areaUnderROC)

assert(model2.summary.accuracy === model.summary.accuracy)
assert(model2.summary.weightedPrecision ~== model.summary.weightedPrecision relTol 1e-6)
assert(model2.summary.weightedRecall === model.summary.weightedRecall)
assert(model2.summary.pr.collect() === model.summary.pr.collect())
assert(model2.summary.roc.collect() === model.summary.roc.collect())
assert(model2.summary.areaUnderROC === model.summary.areaUnderROC)
}

test("linearSVC training summary totalIterations") {
Seq(1, 5, 10, 20, 100).foreach { maxIter =>
val trainer = new LinearSVC().setMaxIter(maxIter)
val model = trainer.fit(smallBinaryDataset)
if (maxIter == 1) {
assert(model.summary.totalIterations === maxIter)
} else {
assert(model.summary.totalIterations <= maxIter)
}
}
}

test("read/write: SVM") {
def checkModelData(model: LinearSVCModel, model2: LinearSVCModel): Unit = {
assert(model.intercept === model2.intercept)
Expand Down
48 changes: 47 additions & 1 deletion python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pyspark.storagelevel import StorageLevel

__all__ = ['LinearSVC', 'LinearSVCModel',
'LinearSVCSummary', 'LinearSVCTrainingSummary',
'LogisticRegression', 'LogisticRegressionModel',
'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
Expand Down Expand Up @@ -683,7 +684,8 @@ def setBlockSize(self, value):
return self._set(blockSize=value)


class LinearSVCModel(_JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable):
class LinearSVCModel(_JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable,
HasTrainingSummary):
"""
Model fitted by LinearSVC.

Expand Down Expand Up @@ -713,6 +715,50 @@ def intercept(self):
"""
return self._call_java("intercept")

@since("3.1.0")
def summary(self):
"""
Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
return LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)

@since("3.1.0")
def evaluate(self, dataset):
"""
Evaluates the model on a test dataset.

:param dataset:
Test dataset to evaluate model on, where dataset is an
instance of :py:class:`pyspark.sql.DataFrame`
"""
if not isinstance(dataset, DataFrame):
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
java_lsvc_summary = self._call_java("evaluate", dataset)
return LinearSVCSummary(java_lsvc_summary)


class LinearSVCSummary(_BinaryClassificationSummary):
"""
Abstraction for LinearSVC Results for a given model.
.. versionadded:: 3.1.0
"""
pass


@inherit_doc
class LinearSVCTrainingSummary(LinearSVCSummary, _TrainingSummary):
"""
Abstraction for LinearSVC Training results.

.. versionadded:: 3.1.0
"""
pass


class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam,
HasElasticNetParam, HasMaxIter, HasFitIntercept, HasTol,
Expand Down
46 changes: 44 additions & 2 deletions python/pyspark/ml/tests/test_training_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
if sys.version > '3':
basestring = str

from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \
LogisticRegressionSummary
from pyspark.ml.classification import BinaryLogisticRegressionSummary, LinearSVC, \
LinearSVCSummary, LogisticRegression, LogisticRegressionSummary
from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
Expand Down Expand Up @@ -193,6 +193,48 @@ def test_multiclass_logistic_regression_summary(self):
self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary))
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)

def test_linear_svc_summary(self):
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0, 1.0, 1.0)),
(0.0, 2.0, Vectors.dense(1.0, 2.0, 3.0))],
["label", "weight", "features"])
svc = LinearSVC(maxIter=5, weightCol="weight")
model = svc.fit(df)
self.assertTrue(model.hasSummary)
s = model.summary()
# test that api is callable and returns expected types
self.assertTrue(isinstance(s.predictions, DataFrame))
self.assertEqual(s.scoreCol, "rawPrediction")
self.assertEqual(s.labelCol, "label")
self.assertEqual(s.predictionCol, "prediction")
objHist = s.objectiveHistory
self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
self.assertGreater(s.totalIterations, 0)
self.assertTrue(isinstance(s.labels, list))
self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
self.assertTrue(isinstance(s.precisionByLabel, list))
self.assertTrue(isinstance(s.recallByLabel, list))
self.assertTrue(isinstance(s.fMeasureByLabel(), list))
self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
self.assertTrue(isinstance(s.roc, DataFrame))
self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
self.assertTrue(isinstance(s.pr, DataFrame))
self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
print(s.weightedTruePositiveRate)
self.assertAlmostEqual(s.weightedTruePositiveRate, 0.5, 2)
self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.5, 2)
self.assertAlmostEqual(s.weightedRecall, 0.5, 2)
self.assertAlmostEqual(s.weightedPrecision, 0.25, 2)
self.assertAlmostEqual(s.weightedFMeasure(), 0.3333333333333333, 2)
self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.3333333333333333, 2)
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertTrue(isinstance(sameSummary, LinearSVCSummary))
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)

def test_gaussian_mixture_summary(self):
data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
(Vectors.sparse(1, [], []),)]
Expand Down