Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,10 @@ def evaluate(self, dataset):
if not isinstance(dataset, DataFrame):
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
java_blr_summary = self._call_java("evaluate", dataset)
return BinaryLogisticRegressionSummary(java_blr_summary)
if self.numClasses <= 2:
return BinaryLogisticRegressionSummary(java_blr_summary)
else:
return LogisticRegressionSummary(java_blr_summary)


class LogisticRegressionSummary(JavaWrapper):
Expand Down
6 changes: 5 additions & 1 deletion python/pyspark/ml/tests/test_training_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
if sys.version > '3':
basestring = str

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \
LogisticRegressionSummary
from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
Expand Down Expand Up @@ -149,6 +150,7 @@ def test_binary_logistic_regression_summary(self):
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary))
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)

def test_multiclass_logistic_regression_summary(self):
Expand Down Expand Up @@ -187,6 +189,8 @@ def test_multiclass_logistic_regression_summary(self):
# test evaluation (with training dataset) produces a summary with same values
# one check is enough to verify a summary is returned, Scala version runs full test
sameSummary = model.evaluate(df)
self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary))
self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary))
self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)

def test_gaussian_mixture_summary(self):
Expand Down