Skip to content

Commit e10516a

Browse files
huaxingao authored and srowen committed
[SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary
### What changes were proposed in this pull request?

Return LogisticRegressionSummary for multiclass logistic regression evaluate in PySpark.

### Why are the changes needed?

Currently we have

```
@since("2.0.0")
def evaluate(self, dataset):
    if not isinstance(dataset, DataFrame):
        raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
    java_blr_summary = self._call_java("evaluate", dataset)
    return BinaryLogisticRegressionSummary(java_blr_summary)
```

We should return LogisticRegressionSummary for multiclass logistic regression.

### Does this PR introduce _any_ user-facing change?

Yes. `evaluate` now returns LogisticRegressionSummary instead of BinaryLogisticRegressionSummary for multiclass logistic regression in Python.

### How was this patch tested?

Unit test.

Closes #28503 from huaxingao/lr_summary.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
1 parent b2300fc commit e10516a

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

python/pyspark/ml/classification.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,10 @@ def evaluate(self, dataset):
932932
if not isinstance(dataset, DataFrame):
933933
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
934934
java_blr_summary = self._call_java("evaluate", dataset)
935-
return BinaryLogisticRegressionSummary(java_blr_summary)
935+
if self.numClasses <= 2:
936+
return BinaryLogisticRegressionSummary(java_blr_summary)
937+
else:
938+
return LogisticRegressionSummary(java_blr_summary)
936939

937940

938941
class LogisticRegressionSummary(JavaWrapper):

python/pyspark/ml/tests/test_training_summary.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
if sys.version > '3':
2222
basestring = str
2323

24-
from pyspark.ml.classification import LogisticRegression
24+
from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \
25+
LogisticRegressionSummary
2526
from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
2627
from pyspark.ml.linalg import Vectors
2728
from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
@@ -149,6 +150,7 @@ def test_binary_logistic_regression_summary(self):
149150
# test evaluation (with training dataset) produces a summary with same values
150151
# one check is enough to verify a summary is returned, Scala version runs full test
151152
sameSummary = model.evaluate(df)
153+
self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary))
152154
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
153155

154156
def test_multiclass_logistic_regression_summary(self):
@@ -187,6 +189,8 @@ def test_multiclass_logistic_regression_summary(self):
187189
# test evaluation (with training dataset) produces a summary with same values
188190
# one check is enough to verify a summary is returned, Scala version runs full test
189191
sameSummary = model.evaluate(df)
192+
self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary))
193+
self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary))
190194
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
191195

192196
def test_gaussian_mixture_summary(self):

0 commit comments

Comments
 (0)