diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala new file mode 100644 index 0000000000000..e9ea38161d3c0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ClassificationSummary.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.DoubleType + + +/** + * Abstraction for multiclass classification results for a given model. + */ +private[classification] trait ClassificationSummary extends Serializable { + + /** + * Dataframe output by the model's `transform` method. + */ + @Since("3.1.0") + def predictions: DataFrame + + /** Field in "predictions" which gives the prediction of each class. 
*/ + @Since("3.1.0") + def predictionCol: String + + /** Field in "predictions" which gives the true label of each instance (if available). */ + @Since("3.1.0") + def labelCol: String + + /** Field in "predictions" which gives the non-negative weight of each instance. */ + @Since("3.1.0") + def weightCol: String + + @transient private val multiclassMetrics = { + val weightColumn = if (predictions.schema.fieldNames.contains(weightCol)) { + org.apache.spark.ml.functions.checkNonNegativeWeight(col(weightCol).cast(DoubleType)) + } else { + lit(1.0) + } + new MulticlassMetrics( + predictions.select(col(predictionCol), col(labelCol).cast(DoubleType), weightColumn) + .rdd.map { + case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) + }) + } + + /** + * Returns the sequence of labels in ascending order. This order matches the order used + * in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel. + * + * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the + * training set is missing a label, then all of the arrays over labels + * (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the + * expected numClasses. + */ + @Since("3.1.0") + def labels: Array[Double] = multiclassMetrics.labels + + /** Returns true positive rate for each label (category). */ + @Since("3.1.0") + def truePositiveRateByLabel: Array[Double] = recallByLabel + + /** Returns false positive rate for each label (category). */ + @Since("3.1.0") + def falsePositiveRateByLabel: Array[Double] = { + multiclassMetrics.labels.map(label => multiclassMetrics.falsePositiveRate(label)) + } + + /** Returns precision for each label (category). */ + @Since("3.1.0") + def precisionByLabel: Array[Double] = { + multiclassMetrics.labels.map(label => multiclassMetrics.precision(label)) + } + + /** Returns recall for each label (category). 
*/ + @Since("3.1.0") + def recallByLabel: Array[Double] = { + multiclassMetrics.labels.map(label => multiclassMetrics.recall(label)) + } + + /** Returns f-measure for each label (category). */ + @Since("3.1.0") + def fMeasureByLabel(beta: Double): Array[Double] = { + multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label, beta)) + } + + /** Returns f1-measure for each label (category). */ + @Since("3.1.0") + def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0) + + /** + * Returns accuracy. + * (equals to the total number of correctly classified instances + * out of the total number of instances.) + */ + @Since("3.1.0") + def accuracy: Double = multiclassMetrics.accuracy + + /** + * Returns weighted true positive rate. + * (equals to precision, recall and f-measure) + */ + @Since("3.1.0") + def weightedTruePositiveRate: Double = weightedRecall + + /** Returns weighted false positive rate. */ + @Since("3.1.0") + def weightedFalsePositiveRate: Double = multiclassMetrics.weightedFalsePositiveRate + + /** + * Returns weighted averaged recall. + * (equals to precision, recall and f-measure) + */ + @Since("3.1.0") + def weightedRecall: Double = multiclassMetrics.weightedRecall + + /** Returns weighted averaged precision. */ + @Since("3.1.0") + def weightedPrecision: Double = multiclassMetrics.weightedPrecision + + /** Returns weighted averaged f-measure. */ + @Since("3.1.0") + def weightedFMeasure(beta: Double): Double = multiclassMetrics.weightedFMeasure(beta) + + /** Returns weighted averaged f1-measure. */ + @Since("3.1.0") + def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0) +} + +/** + * Abstraction for training results. + */ +private[classification] trait TrainingSummary { + + /** + * objective function (scaled loss + regularization) at each iteration. + * It contains one more element, the initial state, than number of iterations. 
+ */ + @Since("3.1.0") + def objectiveHistory: Array[Double] + + /** Number of training iterations. */ + @Since("3.1.0") + def totalIterations: Int = { + assert(objectiveHistory.length > 0, "objectiveHistory length should be greater than 0.") + objectiveHistory.length - 1 + } +} + +/** + * Abstraction for binary classification results for a given model. + */ +private[classification] trait BinaryClassificationSummary extends ClassificationSummary { + + private val sparkSession = predictions.sparkSession + import sparkSession.implicits._ + + /** + * Field in "predictions" which gives the probability or rawPrediction of each class as a + * vector. + */ + def scoreCol: String = null + + @transient private val binaryMetrics = { + val weightColumn = if (predictions.schema.fieldNames.contains(weightCol)) { + org.apache.spark.ml.functions.checkNonNegativeWeight(col(weightCol).cast(DoubleType)) + } else { + lit(1.0) + } + + // TODO: Allow the user to vary the number of bins using a setBins method in + // BinaryClassificationMetrics. For now the default is set to 1000. + new BinaryClassificationMetrics( + predictions.select(col(scoreCol), col(labelCol).cast(DoubleType), weightColumn).rdd.map { + case Row(score: Vector, label: Double, weight: Double) => (score(1), label, weight) + }, 1000 + ) + } + + /** + * Returns the receiver operating characteristic (ROC) curve, + * which is a Dataframe having two fields (FPR, TPR) + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic + */ + @Since("3.1.0") + @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") + + /** + * Computes the area under the receiver operating characteristic (ROC) curve. + */ + @Since("3.1.0") + lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() + + /** + * Returns the precision-recall curve, which is a Dataframe containing + * two fields recall, precision with (0.0, 1.0) prepended to it. 
+ */ + @Since("3.1.0") + @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") + + /** + * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + */ + @Since("3.1.0") + @transient lazy val fMeasureByThreshold: DataFrame = { + binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") + } + + /** + * Returns a dataframe with two fields (threshold, precision) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the precision. + */ + @Since("3.1.0") + @transient lazy val precisionByThreshold: DataFrame = { + binaryMetrics.precisionByThreshold().toDF("threshold", "precision") + } + + /** + * Returns a dataframe with two fields (threshold, recall) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the recall. + */ + @Since("3.1.0") + @transient lazy val recallByThreshold: DataFrame = { + binaryMetrics.recallByThreshold().toDF("threshold", "recall") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 1f5976c59235b..20d619334f7b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,7 +29,6 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.feature._ -import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg._ import org.apache.spark.ml.optim.aggregator._ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} @@ -38,12 +37,10 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.stat._ import org.apache.spark.ml.util._ import 
org.apache.spark.ml.util.Instrumentation.instrumented -import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, lit} -import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils @@ -1396,136 +1393,16 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { /** * Abstraction for logistic regression results for a given model. */ -sealed trait LogisticRegressionSummary extends Serializable { - - /** - * Dataframe output by the model's `transform` method. - */ - @Since("1.5.0") - def predictions: DataFrame +sealed trait LogisticRegressionSummary extends ClassificationSummary { /** Field in "predictions" which gives the probability of each class as a vector. */ @Since("1.5.0") def probabilityCol: String - /** Field in "predictions" which gives the prediction of each class. */ - @Since("2.3.0") - def predictionCol: String - - /** Field in "predictions" which gives the true label of each instance (if available). */ - @Since("1.5.0") - def labelCol: String - /** Field in "predictions" which gives the features of each instance as a vector. */ @Since("1.6.0") def featuresCol: String - /** Field in "predictions" which gives the weight of each instance as a vector. 
*/ - @Since("3.1.0") - def weightCol: String - - @transient private val multiclassMetrics = { - if (predictions.schema.fieldNames.contains(weightCol)) { - new MulticlassMetrics( - predictions.select( - col(predictionCol), - col(labelCol).cast(DoubleType), - checkNonNegativeWeight(col(weightCol).cast(DoubleType))).rdd.map { - case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) - }) - } else { - new MulticlassMetrics( - predictions.select( - col(predictionCol), - col(labelCol).cast(DoubleType), - lit(1.0)).rdd.map { - case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) - }) - } - } - - /** - * Returns the sequence of labels in ascending order. This order matches the order used - * in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel. - * - * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the - * training set is missing a label, then all of the arrays over labels - * (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the - * expected numClasses. - */ - @Since("2.3.0") - def labels: Array[Double] = multiclassMetrics.labels - - /** Returns true positive rate for each label (category). */ - @Since("2.3.0") - def truePositiveRateByLabel: Array[Double] = recallByLabel - - /** Returns false positive rate for each label (category). */ - @Since("2.3.0") - def falsePositiveRateByLabel: Array[Double] = { - multiclassMetrics.labels.map(label => multiclassMetrics.falsePositiveRate(label)) - } - - /** Returns precision for each label (category). */ - @Since("2.3.0") - def precisionByLabel: Array[Double] = { - multiclassMetrics.labels.map(label => multiclassMetrics.precision(label)) - } - - /** Returns recall for each label (category). 
*/ - @Since("2.3.0") - def recallByLabel: Array[Double] = { - multiclassMetrics.labels.map(label => multiclassMetrics.recall(label)) - } - - /** Returns f-measure for each label (category). */ - @Since("2.3.0") - def fMeasureByLabel(beta: Double): Array[Double] = { - multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label, beta)) - } - - /** Returns f1-measure for each label (category). */ - @Since("2.3.0") - def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0) - - /** - * Returns accuracy. - * (equals to the total number of correctly classified instances - * out of the total number of instances.) - */ - @Since("2.3.0") - def accuracy: Double = multiclassMetrics.accuracy - - /** - * Returns weighted true positive rate. - * (equals to precision, recall and f-measure) - */ - @Since("2.3.0") - def weightedTruePositiveRate: Double = weightedRecall - - /** Returns weighted false positive rate. */ - @Since("2.3.0") - def weightedFalsePositiveRate: Double = multiclassMetrics.weightedFalsePositiveRate - - /** - * Returns weighted averaged recall. - * (equals to precision, recall and f-measure) - */ - @Since("2.3.0") - def weightedRecall: Double = multiclassMetrics.weightedRecall - - /** Returns weighted averaged precision. */ - @Since("2.3.0") - def weightedPrecision: Double = multiclassMetrics.weightedPrecision - - /** Returns weighted averaged f-measure. */ - @Since("2.3.0") - def weightedFMeasure(beta: Double): Double = multiclassMetrics.weightedFMeasure(beta) - - /** Returns weighted averaged f1-measure. */ - @Since("2.3.0") - def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0) - /** * Convenient method for casting to binary logistic regression summary. * This method will throw an Exception if the summary is not a binary summary. @@ -1540,101 +1417,21 @@ sealed trait LogisticRegressionSummary extends Serializable { /** * Abstraction for multiclass logistic regression training results. 
- * Currently, the training summary ignores the training weights except - * for the objective trace. */ -sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { - - /** - * objective function (scaled loss + regularization) at each iteration. - * It contains one more element, the initial state, than number of iterations. - */ - @Since("1.5.0") - def objectiveHistory: Array[Double] - - /** Number of training iterations. */ - @Since("1.5.0") - def totalIterations: Int = { - assert(objectiveHistory.length > 0, s"objectiveHistory length should be greater than 1.") - objectiveHistory.length - 1 - } - +sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary + with TrainingSummary { } /** * Abstraction for binary logistic regression results for a given model. */ -sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary { - - private val sparkSession = predictions.sparkSession - import sparkSession.implicits._ - - // TODO: Allow the user to vary the number of bins using a setBins method in - // BinaryClassificationMetrics. For now the default is set to 100. 
- @transient private val binaryMetrics = if (predictions.schema.fieldNames.contains(weightCol)) { - new BinaryClassificationMetrics( - predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType), - checkNonNegativeWeight(col(weightCol).cast(DoubleType))).rdd.map { - case Row(score: Vector, label: Double, weight: Double) => (score(1), label, weight) - }, 100 - ) - } else { - new BinaryClassificationMetrics( - predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType), - lit(1.0)).rdd.map { - case Row(score: Vector, label: Double, weight: Double) => (score(1), label, weight) - }, 100 - ) - } +sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary + with BinaryClassificationSummary { - /** - * Returns the receiver operating characteristic (ROC) curve, - * which is a Dataframe having two fields (FPR, TPR) - * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. - * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic - */ - @Since("1.5.0") - @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") - - /** - * Computes the area under the receiver operating characteristic (ROC) curve. - */ - @Since("1.5.0") - lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() - - /** - * Returns the precision-recall curve, which is a Dataframe containing - * two fields recall, precision with (0.0, 1.0) prepended to it. - */ - @Since("1.5.0") - @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") - - /** - * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. - */ - @Since("1.5.0") - @transient lazy val fMeasureByThreshold: DataFrame = { - binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") - } - - /** - * Returns a dataframe with two fields (threshold, precision) curve. - * Every possible probability obtained in transforming the dataset are used - * as thresholds used in calculating the precision. 
- */ - @Since("1.5.0") - @transient lazy val precisionByThreshold: DataFrame = { - binaryMetrics.precisionByThreshold().toDF("threshold", "precision") - } - - /** - * Returns a dataframe with two fields (threshold, recall) curve. - * Every possible probability obtained in transforming the dataset are used - * as thresholds used in calculating the recall. - */ - @Since("1.5.0") - @transient lazy val recallByThreshold: DataFrame = { - binaryMetrics.recallByThreshold().toDF("threshold", "recall") + override def scoreCol: String = if (probabilityCol.nonEmpty) { + probabilityCol + } else { + throw new SparkException("probabilityCol is required for BinaryLogisticRegressionSummary.") } } @@ -1674,7 +1471,7 @@ private class LogisticRegressionTrainingSummaryImpl( * * @param predictions dataframe output by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the probability of - * each class as a vector. + * each class as a vector. * @param predictionCol field in "predictions" which gives the prediction for a data instance as a * double. * @param labelCol field in "predictions" which gives the true label of each instance. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 30c21d8b06670..ecee531c88a8f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -313,12 +313,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(mlorModel2.summary.isInstanceOf[LogisticRegressionTrainingSummary]) withClue("cannot get binary summary for multiclass model") { intercept[RuntimeException] { - mlorModel.binarySummary + mlorModel2.binarySummary } } withClue("cannot cast summary to binary summary multiclass model") { intercept[RuntimeException] { - mlorModel.summary.asBinary + mlorModel2.summary.asBinary } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index addb2d8152189..0be7b4c1003a7 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -39,6 +39,7 @@ object MimaExcludes { // [SPARK-31077] Remove ChiSqSelector dependency on mllib.ChiSqSelectorModel // private constructor ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.this"), + // [SPARK-31127] Implement abstract Selector // org.apache.spark.ml.feature.ChiSqSelectorModel type hierarchy change // before: class ChiSqSelector extends Estimator with ChiSqSelectorParams @@ -46,11 +47,31 @@ object MimaExcludes { // false positive, no binary incompatibility ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelector"), - //[SPARK-31840] Add instance weight support in LogisticRegressionSummary - // weightCol in org.apache.spark.ml.classification.LogisticRegressionSummary is present only in current version - 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightCol"), + // [SPARK-24634] Add a new metric regarding number of inputs later than watermark plus allowed delay - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.$default$4") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.$default$4"), + + //[SPARK-31893] Add a generic ClassificationSummary trait + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.weightCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics"), + 
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.weightCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession_="), + 
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.weightCol") ) // Exclude rules for 3.0.x diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 3f3699ce53b51..ff506066519cd 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -240,6 +240,259 @@ def predictProbability(self, value): return self._call_java("predictProbability", value) +@inherit_doc +class _ClassificationSummary(JavaWrapper): + """ + Abstraction for multiclass classification results for a given model. + + .. versionadded:: 3.1.0 + """ + + @property + @since("3.1.0") + def predictions(self): + """ + Dataframe outputted by the model's `transform` method. 
+ """ + return self._call_java("predictions") + + @property + @since("3.1.0") + def predictionCol(self): + """ + Field in "predictions" which gives the prediction of each class. + """ + return self._call_java("predictionCol") + + @property + @since("3.1.0") + def labelCol(self): + """ + Field in "predictions" which gives the true label of each + instance. + """ + return self._call_java("labelCol") + + @property + @since("3.1.0") + def weightCol(self): + """ + Field in "predictions" which gives the weight of each instance + as a vector. + """ + return self._call_java("weightCol") + + @property + @since("3.1.0") + def labels(self): + """ + Returns the sequence of labels in ascending order. This order matches the order used + in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel. + + Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the + training set is missing a label, then all of the arrays over labels + (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the + expected numClasses. + """ + return self._call_java("labels") + + @property + @since("3.1.0") + def truePositiveRateByLabel(self): + """ + Returns true positive rate for each label (category). + """ + return self._call_java("truePositiveRateByLabel") + + @property + @since("3.1.0") + def falsePositiveRateByLabel(self): + """ + Returns false positive rate for each label (category). + """ + return self._call_java("falsePositiveRateByLabel") + + @property + @since("3.1.0") + def precisionByLabel(self): + """ + Returns precision for each label (category). + """ + return self._call_java("precisionByLabel") + + @property + @since("3.1.0") + def recallByLabel(self): + """ + Returns recall for each label (category). + """ + return self._call_java("recallByLabel") + + @since("3.1.0") + def fMeasureByLabel(self, beta=1.0): + """ + Returns f-measure for each label (category). 
+ """ + return self._call_java("fMeasureByLabel", beta) + + @property + @since("3.1.0") + def accuracy(self): + """ + Returns accuracy. + (equals to the total number of correctly classified instances + out of the total number of instances.) + """ + return self._call_java("accuracy") + + @property + @since("3.1.0") + def weightedTruePositiveRate(self): + """ + Returns weighted true positive rate. + (equals to precision, recall and f-measure) + """ + return self._call_java("weightedTruePositiveRate") + + @property + @since("3.1.0") + def weightedFalsePositiveRate(self): + """ + Returns weighted false positive rate. + """ + return self._call_java("weightedFalsePositiveRate") + + @property + @since("3.1.0") + def weightedRecall(self): + """ + Returns weighted averaged recall. + (equals to precision, recall and f-measure) + """ + return self._call_java("weightedRecall") + + @property + @since("3.1.0") + def weightedPrecision(self): + """ + Returns weighted averaged precision. + """ + return self._call_java("weightedPrecision") + + @since("3.1.0") + def weightedFMeasure(self, beta=1.0): + """ + Returns weighted averaged f-measure. + """ + return self._call_java("weightedFMeasure", beta) + + +@inherit_doc +class _TrainingSummary(JavaWrapper): + """ + Abstraction for Training results. + + .. versionadded:: 3.1.0 + """ + + @property + @since("3.1.0") + def objectiveHistory(self): + """ + Objective function (scaled loss + regularization) at each + iteration. It contains one more element, the initial state, + than number of iterations. + """ + return self._call_java("objectiveHistory") + + @property + @since("3.1.0") + def totalIterations(self): + """ + Number of training iterations until termination. + """ + return self._call_java("totalIterations") + + +@inherit_doc +class _BinaryClassificationSummary(_ClassificationSummary): + """ + Binary classification results for a given model. + + .. 
versionadded:: 3.1.0 + """ + + @property + @since("3.1.0") + def scoreCol(self): + """ + Field in "predictions" which gives the probability or raw prediction + of each class as a vector. + """ + return self._call_java("scoreCol") + + @property + @since("3.1.0") + def roc(self): + """ + Returns the receiver operating characteristic (ROC) curve, + which is a Dataframe having two fields (FPR, TPR) with + (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + + .. seealso:: `Wikipedia reference + <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_ + """ + return self._call_java("roc") + + @property + @since("3.1.0") + def areaUnderROC(self): + """ + Computes the area under the receiver operating characteristic + (ROC) curve. + """ + return self._call_java("areaUnderROC") + + @property + @since("3.1.0") + def pr(self): + """ + Returns the precision-recall curve, which is a Dataframe + containing two fields recall, precision with (0.0, 1.0) prepended + to it. + """ + return self._call_java("pr") + + @property + @since("3.1.0") + def fMeasureByThreshold(self): + """ + Returns a dataframe with two fields (threshold, F-Measure) curve + with beta = 1.0. + """ + return self._call_java("fMeasureByThreshold") + + @property + @since("3.1.0") + def precisionByThreshold(self): + """ + Returns a dataframe with two fields (threshold, precision) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the precision. + """ + return self._call_java("precisionByThreshold") + + @property + @since("3.1.0") + def recallByThreshold(self): + """ + Returns a dataframe with two fields (threshold, recall) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the recall. 
+ """ + return self._call_java("recallByThreshold") + + class _LinearSVCParams(_ClassifierParams, HasRegParam, HasMaxIter, HasFitIntercept, HasTol, HasStandardization, HasWeightCol, HasAggregationDepth, HasThreshold, HasBlockSize): @@ -940,21 +1193,13 @@ def evaluate(self, dataset): return LogisticRegressionSummary(java_blr_summary) -class LogisticRegressionSummary(JavaWrapper): +class LogisticRegressionSummary(_ClassificationSummary): """ Abstraction for Logistic Regression Results for a given model. .. versionadded:: 2.0.0 """ - @property - @since("2.0.0") - def predictions(self): - """ - Dataframe outputted by the model's `transform` method. - """ - return self._call_java("predictions") - @property @since("2.0.0") def probabilityCol(self): @@ -964,23 +1209,6 @@ def probabilityCol(self): """ return self._call_java("probabilityCol") - @property - @since("2.3.0") - def predictionCol(self): - """ - Field in "predictions" which gives the prediction of each class. - """ - return self._call_java("predictionCol") - - @property - @since("2.0.0") - def labelCol(self): - """ - Field in "predictions" which gives the true label of each - instance. - """ - return self._call_java("labelCol") - @property @since("2.0.0") def featuresCol(self): @@ -990,241 +1218,26 @@ def featuresCol(self): """ return self._call_java("featuresCol") - @property - @since("3.1.0") - def weightCol(self): - """ - Field in "predictions" which gives the weight of each instance - as a vector. - """ - return self._call_java("weightCol") - - @property - @since("2.3.0") - def labels(self): - """ - Returns the sequence of labels in ascending order. This order matches the order used - in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel. 
- - Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the - training set is missing a label, then all of the arrays over labels - (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the - expected numClasses. - """ - return self._call_java("labels") - - @property - @since("2.3.0") - def truePositiveRateByLabel(self): - """ - Returns true positive rate for each label (category). - """ - return self._call_java("truePositiveRateByLabel") - - @property - @since("2.3.0") - def falsePositiveRateByLabel(self): - """ - Returns false positive rate for each label (category). - """ - return self._call_java("falsePositiveRateByLabel") - - @property - @since("2.3.0") - def precisionByLabel(self): - """ - Returns precision for each label (category). - """ - return self._call_java("precisionByLabel") - - @property - @since("2.3.0") - def recallByLabel(self): - """ - Returns recall for each label (category). - """ - return self._call_java("recallByLabel") - - @since("2.3.0") - def fMeasureByLabel(self, beta=1.0): - """ - Returns f-measure for each label (category). - """ - return self._call_java("fMeasureByLabel", beta) - - @property - @since("2.3.0") - def accuracy(self): - """ - Returns accuracy. - (equals to the total number of correctly classified instances - out of the total number of instances.) - """ - return self._call_java("accuracy") - - @property - @since("2.3.0") - def weightedTruePositiveRate(self): - """ - Returns weighted true positive rate. - (equals to precision, recall and f-measure) - """ - return self._call_java("weightedTruePositiveRate") - - @property - @since("2.3.0") - def weightedFalsePositiveRate(self): - """ - Returns weighted false positive rate. - """ - return self._call_java("weightedFalsePositiveRate") - - @property - @since("2.3.0") - def weightedRecall(self): - """ - Returns weighted averaged recall. 
- (equals to precision, recall and f-measure) - """ - return self._call_java("weightedRecall") - - @property - @since("2.3.0") - def weightedPrecision(self): - """ - Returns weighted averaged precision. - """ - return self._call_java("weightedPrecision") - - @since("2.3.0") - def weightedFMeasure(self, beta=1.0): - """ - Returns weighted averaged f-measure. - """ - return self._call_java("weightedFMeasure", beta) - @inherit_doc -class LogisticRegressionTrainingSummary(LogisticRegressionSummary): +class LogisticRegressionTrainingSummary(LogisticRegressionSummary, _TrainingSummary): """ Abstraction for multinomial Logistic Regression Training results. - Currently, the training summary ignores the training weights except - for the objective trace. .. versionadded:: 2.0.0 """ - - @property - @since("2.0.0") - def objectiveHistory(self): - """ - Objective function (scaled loss + regularization) at each - iteration. It contains one more element, the initial state, - than number of iterations. - """ - return self._call_java("objectiveHistory") - - @property - @since("2.0.0") - def totalIterations(self): - """ - Number of training iterations until termination. - """ - return self._call_java("totalIterations") + pass @inherit_doc -class BinaryLogisticRegressionSummary(LogisticRegressionSummary): +class BinaryLogisticRegressionSummary(_BinaryClassificationSummary, + LogisticRegressionSummary): """ Binary Logistic regression results for a given model. .. versionadded:: 2.0.0 """ - - @property - @since("2.0.0") - def roc(self): - """ - Returns the receiver operating characteristic (ROC) curve, - which is a Dataframe having two fields (FPR, TPR) with - (0.0, 0.0) prepended and (1.0, 1.0) appended to it. - - .. seealso:: `Wikipedia reference - <http://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_ - - .. note:: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions.
- """ - return self._call_java("roc") - - @property - @since("2.0.0") - def areaUnderROC(self): - """ - Computes the area under the receiver operating characteristic - (ROC) curve. - - .. note:: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. - """ - return self._call_java("areaUnderROC") - - @property - @since("2.0.0") - def pr(self): - """ - Returns the precision-recall curve, which is a Dataframe - containing two fields recall, precision with (0.0, 1.0) prepended - to it. - - .. note:: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. - """ - return self._call_java("pr") - - @property - @since("2.0.0") - def fMeasureByThreshold(self): - """ - Returns a dataframe with two fields (threshold, F-Measure) curve - with beta = 1.0. - - .. note:: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. - """ - return self._call_java("fMeasureByThreshold") - - @property - @since("2.0.0") - def precisionByThreshold(self): - """ - Returns a dataframe with two fields (threshold, precision) curve. - Every possible probability obtained in transforming the dataset - are used as thresholds used in calculating the precision. - - .. note:: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. - """ - return self._call_java("precisionByThreshold") - - @property - @since("2.0.0") - def recallByThreshold(self): - """ - Returns a dataframe with two fields (threshold, recall) curve. - Every possible probability obtained in transforming the dataset - are used as thresholds used in calculating the recall. - - .. note:: This ignores instance weights (setting all to 1.0) from - `LogisticRegression.weightCol`. This will change in later Spark - versions. 
- """ - return self._call_java("recallByThreshold") + pass @inherit_doc