Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions python/pyspark/ml/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
from pyspark.ml.common import inherit_doc
from pyspark.ml.util import JavaMLReadable, JavaMLWritable

__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
'MulticlassClassificationEvaluator']
Expand Down Expand Up @@ -103,7 +104,8 @@ def isLargerBetter(self):


@inherit_doc
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol,
JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental

Expand All @@ -121,6 +123,11 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
0.70...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
0.83...
>>> bce_path = temp_path + "/bce"
>>> evaluator.save(bce_path)
>>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
>>> str(evaluator2.getRawPredictionCol())
'raw'

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -172,7 +179,8 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",


@inherit_doc
class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental

Expand All @@ -190,6 +198,11 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
0.993...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
2.649...
>>> re_path = temp_path + "/re"
>>> evaluator.save(re_path)
>>> evaluator2 = RegressionEvaluator.load(re_path)
>>> str(evaluator2.getPredictionCol())
'prediction'

.. versionadded:: 1.4.0
"""
Expand Down Expand Up @@ -244,7 +257,8 @@ def setParams(self, predictionCol="prediction", labelCol="label",


@inherit_doc
class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental

Expand All @@ -260,6 +274,11 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
0.66...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
0.66...
>>> mce_path = temp_path + "/mce"
>>> evaluator.save(mce_path)
>>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
>>> str(evaluator2.getPredictionCol())
'prediction'

.. versionadded:: 1.5.0
"""
Expand Down Expand Up @@ -311,19 +330,27 @@ def setParams(self, predictionCol="prediction", labelCol="label",

if __name__ == "__main__":
    import doctest
    import tempfile
    import pyspark.ml.evaluation
    from pyspark.sql import SparkSession

    # Use the module's own namespace (not this script's globals) so the
    # doctests resolve the evaluator classes exactly as the module defines them.
    globs = pyspark.ml.evaluation.__dict__.copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    spark = SparkSession.builder\
        .master("local[2]")\
        .appName("ml.evaluation tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
    # Scratch directory for the save/load doctests; exposed to them as
    # `temp_path` and always removed in the finally block below.
    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        (failure_count, test_count) = doctest.testmod(
            globs=globs, optionflags=doctest.ELLIPSIS)
        spark.stop()
    finally:
        from shutil import rmtree
        try:
            rmtree(temp_path)
        except OSError:
            # Best-effort cleanup: a missing or busy temp dir must not
            # mask the doctest result.
            pass
    if failure_count:
        exit(-1)