diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 217398c51b393..1659bbb1d34b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -36,7 +36,7 @@ import org.apache.spark.ml.stat._ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql._ import org.apache.spark.storage.StorageLevel /** Params for linear SVM Classifier. */ @@ -267,7 +267,26 @@ class LinearSVC @Since("2.2.0") ( if (featuresStd(i) != 0.0) rawCoefficients(i) / featuresStd(i) else 0.0 } val intercept = if ($(fitIntercept)) rawCoefficients.last else 0.0 - copyValues(new LinearSVCModel(uid, Vectors.dense(coefficientArray), intercept)) + createModel(dataset, Vectors.dense(coefficientArray), intercept, objectiveHistory) + } + + private def createModel( + dataset: Dataset[_], + coefficients: Vector, + intercept: Double, + objectiveHistory: Array[Double]): LinearSVCModel = { + val model = copyValues(new LinearSVCModel(uid, coefficients, intercept)) + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + + val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel() + val summary = new LinearSVCTrainingSummaryImpl( + summaryModel.transform(dataset), + rawPredictionColName, + predictionColName, + $(labelCol), + weightColName, + objectiveHistory) + model.setSummary(Some(summary)) } private def trainOnRows( @@ -352,7 +371,7 @@ class LinearSVCModel private[classification] ( @Since("2.2.0") val coefficients: Vector, @Since("2.2.0") val intercept: Double) extends ClassificationModel[Vector, LinearSVCModel] - with LinearSVCParams with MLWritable { + with LinearSVCParams with MLWritable with HasTrainingSummary[LinearSVCTrainingSummary] { @Since("2.2.0") override val numClasses: Int = 2 @@ -368,6 +387,48 @@ class LinearSVCModel private[classification] ( BLAS.dot(features, coefficients) + intercept } + /** + * Gets summary of model on training set. An exception is thrown + * if `hasSummary` is false. + */ + @Since("3.1.0") + override def summary: LinearSVCTrainingSummary = super.summary + + /** + * If the rawPrediction and prediction columns are set, this method returns the current model, + * otherwise it generates new columns for them and sets them as columns on a new copy of + * the current model + */ + private[classification] def findSummaryModel(): (LinearSVCModel, String, String) = { + val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) { + copy(ParamMap.empty) + .setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString) + .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else if ($(rawPredictionCol).isEmpty) { + copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" + + java.util.UUID.randomUUID.toString) + } else if ($(predictionCol).isEmpty) { + copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else { + this + } + (model, model.getRawPredictionCol, model.getPredictionCol) + } + + /** + * Evaluates the model on a test dataset. + * + * @param dataset Test dataset to evaluate model on. + */ + @Since("3.1.0") + def evaluate(dataset: Dataset[_]): LinearSVCSummary = { + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + // Handle possible missing or invalid rawPrediction or prediction columns + val (summaryModel, rawPrediction, predictionColName) = findSummaryModel() + new LinearSVCSummaryImpl(summaryModel.transform(dataset), + rawPrediction, predictionColName, $(labelCol), weightColName) + } + override def predict(features: Vector): Double = { if (margin(features) > $(threshold)) 1.0 else 0.0 } @@ -439,3 +500,53 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { } } } + +/** + * Abstraction for LinearSVC results for a given model. + */ +sealed trait LinearSVCSummary extends BinaryClassificationSummary + +/** + * Abstraction for LinearSVC training results. + */ +sealed trait LinearSVCTrainingSummary extends LinearSVCSummary with TrainingSummary + +/** + * LinearSVC results for a given model. + * + * @param predictions dataframe output by the model's `transform` method. + * @param scoreCol field in "predictions" which gives the rawPrediction of each instance. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param weightCol field in "predictions" which gives the weight of each instance. + */ +private class LinearSVCSummaryImpl( + @transient override val predictions: DataFrame, + override val scoreCol: String, + override val predictionCol: String, + override val labelCol: String, + override val weightCol: String) + extends LinearSVCSummary + +/** + * LinearSVC training results. + * + * @param predictions dataframe output by the model's `transform` method. + * @param scoreCol field in "predictions" which gives the rawPrediction of each instance. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param weightCol field in "predictions" which gives the weight of each instance. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +private class LinearSVCTrainingSummaryImpl( + predictions: DataFrame, + scoreCol: String, + predictionCol: String, + labelCol: String, + weightCol: String, + override val objectiveHistory: Array[Double]) + extends LinearSVCSummaryImpl( + predictions, scoreCol, predictionCol, labelCol, weightCol) + with LinearSVCTrainingSummary diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 579d6b12ab99f..a66397324c1a6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.functions._ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { @@ -284,6 +284,57 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { assert(model1.coefficients ~== coefficientsSK relTol 4E-3) } + test("summary and training summary") { + val lsvc = new LinearSVC() + val model = lsvc.setMaxIter(5).fit(smallBinaryDataset) + + val summary = model.evaluate(smallBinaryDataset) + + assert(model.summary.accuracy === summary.accuracy) + assert(model.summary.weightedPrecision === summary.weightedPrecision) + assert(model.summary.weightedRecall === summary.weightedRecall) + assert(model.summary.pr.collect() === summary.pr.collect()) + assert(model.summary.roc.collect() === summary.roc.collect()) + assert(model.summary.areaUnderROC === summary.areaUnderROC) + + // verify instance weight works + val lsvc2 = new LinearSVC() + .setMaxIter(5) + .setWeightCol("weight") + + val smallBinaryDatasetWithWeight = + smallBinaryDataset.select(col("label"), col("features"), lit(2.5).as("weight")) + + val summary2 = model.evaluate(smallBinaryDatasetWithWeight) + + val model2 = lsvc2.fit(smallBinaryDatasetWithWeight) + assert(model2.summary.accuracy === summary2.accuracy) + assert(model2.summary.weightedPrecision ~== summary2.weightedPrecision relTol 1e-6) + assert(model2.summary.weightedRecall === summary2.weightedRecall) + assert(model2.summary.pr.collect() === summary2.pr.collect()) + assert(model2.summary.roc.collect() === summary2.roc.collect()) + assert(model2.summary.areaUnderROC === summary2.areaUnderROC) + + assert(model2.summary.accuracy === model.summary.accuracy) + assert(model2.summary.weightedPrecision ~== model.summary.weightedPrecision relTol 1e-6) + assert(model2.summary.weightedRecall === model.summary.weightedRecall) + assert(model2.summary.pr.collect() === model.summary.pr.collect()) + assert(model2.summary.roc.collect() === model.summary.roc.collect()) + assert(model2.summary.areaUnderROC === model.summary.areaUnderROC) + } + + test("linearSVC training summary totalIterations") { + Seq(1, 5, 10, 20, 100).foreach { maxIter => + val trainer = new LinearSVC().setMaxIter(maxIter) + val model = trainer.fit(smallBinaryDataset) + if (maxIter == 1) { + assert(model.summary.totalIterations === maxIter) + } else { + assert(model.summary.totalIterations <= maxIter) + } + } + } + test("read/write: SVM") { def checkModelData(model: LinearSVCModel, model2: LinearSVCModel): Unit = { assert(model.intercept === model2.intercept) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ff506066519cd..bdd37c99df0a8 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -39,6 +39,7 @@ from pyspark.storagelevel import StorageLevel __all__ = ['LinearSVC', 'LinearSVCModel', + 'LinearSVCSummary', 'LinearSVCTrainingSummary', 'LogisticRegression', 'LogisticRegressionModel', 'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary', 'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary', @@ -683,7 +684,8 @@ def setBlockSize(self, value): return self._set(blockSize=value) -class LinearSVCModel(_JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable): +class LinearSVCModel(_JavaClassificationModel, _LinearSVCParams, JavaMLWritable, JavaMLReadable, + HasTrainingSummary): """ Model fitted by LinearSVC. @@ -713,6 +715,50 @@ def intercept(self): """ return self._call_java("intercept") + @since("3.1.0") + def summary(self): + """ + Gets summary (e.g. accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. + """ + if self.hasSummary: + return LinearSVCTrainingSummary(super(LinearSVCModel, self).summary) + else: + raise RuntimeError("No training summary available for this %s" % + self.__class__.__name__) + + @since("3.1.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_lsvc_summary = self._call_java("evaluate", dataset) + return LinearSVCSummary(java_lsvc_summary) + + +class LinearSVCSummary(_BinaryClassificationSummary): + """ + Abstraction for LinearSVC Results for a given model. + .. versionadded:: 3.1.0 + """ + pass + + +@inherit_doc +class LinearSVCTrainingSummary(LinearSVCSummary, _TrainingSummary): + """ + Abstraction for LinearSVC Training results. + + .. versionadded:: 3.1.0 + """ + pass + class _LogisticRegressionParams(_ProbabilisticClassifierParams, HasRegParam, HasElasticNetParam, HasMaxIter, HasFitIntercept, HasTol, diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index ac944d8397a86..19acd194f4ddf 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -21,8 +21,8 @@ if sys.version > '3': basestring = str -from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \ - LogisticRegressionSummary +from pyspark.ml.classification import BinaryLogisticRegressionSummary, LinearSVC, \ + LinearSVCSummary, LogisticRegression, LogisticRegressionSummary from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans from pyspark.ml.linalg import Vectors from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression @@ -193,6 +193,48 @@ def test_multiclass_logistic_regression_summary(self): self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary)) self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) + def test_linear_svc_summary(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0, 1.0, 1.0)), + (0.0, 2.0, Vectors.dense(1.0, 2.0, 3.0))], + ["label", "weight", "features"]) + svc = LinearSVC(maxIter=5, weightCol="weight") + model = svc.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary() + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.scoreCol, "rawPrediction") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + print(s.weightedTruePositiveRate) + self.assertAlmostEqual(s.weightedTruePositiveRate, 0.5, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.5, 2) + self.assertAlmostEqual(s.weightedRecall, 0.5, 2) + self.assertAlmostEqual(s.weightedPrecision, 0.25, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 0.3333333333333333, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.3333333333333333, 2) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, LinearSVCSummary)) + self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) + def test_gaussian_mixture_summary(self): data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),), (Vectors.sparse(1, [], []),)]