
Commit 27e1f38

GayathriMurali authored and jkbradley committed
[SPARK-13034] PySpark ml.classification support export/import
## What changes were proposed in this pull request?

Add export/import for all estimators and transformers (those with a Scala implementation) under pyspark/ml/classification.py.

## How was this patch tested?

./python/run-tests
./dev/lint-python

Unit tests were added to check persistence in Logistic Regression.

Author: GayathriMurali <[email protected]>

Closes #11707 from GayathriMurali/SPARK-13034.
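For readers skimming the diff, here is a minimal sketch of the save/load round trip this patch enables, mirroring the doctests added below. The training data and paths are illustrative, and a running SparkContext `sc` with an SQLContext already created is assumed, as in the doctest harness:

```python
# Sketch of the persistence API added by this patch (illustrative data/paths;
# assumes a running SparkContext `sc` and an SQLContext, per the doctest setup).
import tempfile

from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
from pyspark.sql import Row

temp_path = tempfile.mkdtemp()
df = sc.parallelize([
    Row(label=1.0, features=Vectors.dense(1.0)),
    Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()

lr = LogisticRegression(maxIter=5, regParam=0.01)
model = lr.fit(df)

# Estimator round trip: hyperparameters survive save/load.
lr.save(temp_path + "/lr")
lr2 = LogisticRegression.load(temp_path + "/lr")
assert lr2.getMaxIter() == 5

# Model round trip: fitted coefficients and intercept survive save/load.
model.save(temp_path + "/lr_model")
model2 = LogisticRegressionModel.load(temp_path + "/lr_model")
assert model.intercept == model2.intercept
```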
1 parent 85c42fd commit 27e1f38

File tree (2 files changed: +61, −9 lines):

python/pyspark/ml/classification.py
python/pyspark/ml/tests.py

python/pyspark/ml/classification.py

Lines changed: 43 additions & 9 deletions
@@ -18,7 +18,7 @@
 import warnings
 
 from pyspark import since
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import *
 from pyspark.ml.regression import (
@@ -38,7 +38,7 @@
 class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                          HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
                          HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
-                         HasWeightCol):
+                         HasWeightCol, MLWritable, MLReadable):
     """
     Logistic regression.
     Currently, this class only supports binary classification.
@@ -69,6 +69,18 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
     Traceback (most recent call last):
         ...
     TypeError: Method setParams forces keyword arguments.
+    >>> lr_path = temp_path + "/lr"
+    >>> lr.save(lr_path)
+    >>> lr2 = LogisticRegression.load(lr_path)
+    >>> lr2.getMaxIter()
+    5
+    >>> model_path = temp_path + "/lr_model"
+    >>> model.save(model_path)
+    >>> model2 = LogisticRegressionModel.load(model_path)
+    >>> model.coefficients[0] == model2.coefficients[0]
+    True
+    >>> model.intercept == model2.intercept
+    True
 
     .. versionadded:: 1.3.0
     """
@@ -186,7 +198,7 @@ def _checkThresholdConsistency(self):
                 " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
 
 
-class LogisticRegressionModel(JavaModel):
+class LogisticRegressionModel(JavaModel, MLWritable, MLReadable):
     """
     Model fitted by LogisticRegression.
 
@@ -589,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels):
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
-                 HasRawPredictionCol):
+                 HasRawPredictionCol, MLWritable, MLReadable):
     """
     Naive Bayes Classifiers.
     It supports both Multinomial and Bernoulli NB. Multinomial NB
@@ -623,6 +635,18 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
     >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
     >>> model.transform(test1).head().prediction
     1.0
+    >>> nb_path = temp_path + "/nb"
+    >>> nb.save(nb_path)
+    >>> nb2 = NaiveBayes.load(nb_path)
+    >>> nb2.getSmoothing()
+    1.0
+    >>> model_path = temp_path + "/nb_model"
+    >>> model.save(model_path)
+    >>> model2 = NaiveBayesModel.load(model_path)
+    >>> model.pi == model2.pi
+    True
+    >>> model.theta == model2.theta
+    True
 
     .. versionadded:: 1.5.0
     """
@@ -696,7 +720,7 @@ def getModelType(self):
         return self.getOrDefault(self.modelType)
 
 
-class NaiveBayesModel(JavaModel):
+class NaiveBayesModel(JavaModel, MLWritable, MLReadable):
     """
     Model fitted by NaiveBayes.
 
@@ -853,17 +877,27 @@ def weights(self):
 
 if __name__ == "__main__":
     import doctest
+    import pyspark.ml.classification
     from pyspark.context import SparkContext
     from pyspark.sql import SQLContext
-    globs = globals().copy()
+    globs = pyspark.ml.classification.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     sc = SparkContext("local[2]", "ml.classification tests")
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(
-        globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    import tempfile
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        sc.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
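A note on the doctest-runner change above: wrapping `doctest.testmod` in try/finally guarantees that the scratch directory behind `temp_path` is removed even when a doctest fails, so repeated runs do not leak temp directories. The same best-effort cleanup pattern, sketched standalone (`run_suite` is a hypothetical stand-in for the test body):

```python
import tempfile
from shutil import rmtree

def run_with_temp_dir(run_suite):
    # Create a scratch directory for tests that persist models to disk.
    temp_path = tempfile.mkdtemp()
    try:
        run_suite(temp_path)  # hypothetical test entry point
    finally:
        # Best-effort cleanup: the directory may already be gone.
        try:
            rmtree(temp_path)
        except OSError:
            pass
```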

python/pyspark/ml/tests.py

Lines changed: 18 additions & 0 deletions
@@ -499,6 +499,24 @@ def test_linear_regression(self):
         except OSError:
             pass
 
+    def test_logistic_regression(self):
+        lr = LogisticRegression(maxIter=1)
+        path = tempfile.mkdtemp()
+        lr_path = path + "/logreg"
+        lr.save(lr_path)
+        lr2 = LogisticRegression.load(lr_path)
+        self.assertEqual(lr2.uid, lr2.maxIter.parent,
+                         "Loaded LogisticRegression instance uid (%s) "
+                         "did not match Param's uid (%s)"
+                         % (lr2.uid, lr2.maxIter.parent))
+        self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+                         "Loaded LogisticRegression instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
     def test_pipeline_persistence(self):
         sqlContext = SQLContext(self.sc)
         temp_path = tempfile.mkdtemp()
