diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d635be1d8db8..3bc862cc42af 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -932,7 +932,10 @@ def evaluate(self, dataset): if not isinstance(dataset, DataFrame): raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) java_blr_summary = self._call_java("evaluate", dataset) - return BinaryLogisticRegressionSummary(java_blr_summary) + if self.numClasses <= 2: + return BinaryLogisticRegressionSummary(java_blr_summary) + else: + return LogisticRegressionSummary(java_blr_summary) class LogisticRegressionSummary(JavaWrapper): diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index 1d19ebf9a34a..b5054095d190 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -21,7 +21,8 @@ if sys.version > '3': basestring = str -from pyspark.ml.classification import LogisticRegression +from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \ + LogisticRegressionSummary from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans from pyspark.ml.linalg import Vectors from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression @@ -149,6 +150,7 @@ def test_binary_logistic_regression_summary(self): # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary)) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) def test_multiclass_logistic_regression_summary(self): @@ -187,6 +189,8 @@ def test_multiclass_logistic_regression_summary(self): # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary)) + self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary)) self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) def test_gaussian_mixture_summary(self):