@@ -18,7 +18,7 @@
 import warnings
 
 from pyspark import since
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import *
 from pyspark.ml.regression import (
@@ -38,7 +38,7 @@
 class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                          HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
                          HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
-                         HasWeightCol):
+                         HasWeightCol, MLWritable, MLReadable):
     """
     Logistic regression.
     Currently, this class only supports binary classification.
@@ -69,6 +69,18 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
     Traceback (most recent call last):
         ...
     TypeError: Method setParams forces keyword arguments.
+    >>> lr_path = temp_path + "/lr"
+    >>> lr.save(lr_path)
+    >>> lr2 = LogisticRegression.load(lr_path)
+    >>> lr2.getMaxIter()
+    5
+    >>> model_path = temp_path + "/lr_model"
+    >>> model.save(model_path)
+    >>> model2 = LogisticRegressionModel.load(model_path)
+    >>> model.coefficients[0] == model2.coefficients[0]
+    True
+    >>> model.intercept == model2.intercept
+    True
 
     .. versionadded:: 1.3.0
     """
@@ -186,7 +198,7 @@ def _checkThresholdConsistency(self):
                              " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
 
 
-class LogisticRegressionModel(JavaModel):
+class LogisticRegressionModel(JavaModel, MLWritable, MLReadable):
     """
     Model fitted by LogisticRegression.
 
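Both halves of the pair, the `LogisticRegression` estimator above and the `LogisticRegressionModel` here, mix in `MLWritable` and `MLReadable`, which arrive via the widened `from pyspark.ml.util import *` at the top of the diff. The util module itself is outside this diff; as far as the doctests are concerned, the contract is just an instance-level `save()` and a class-level `load()`. An interface sketch only (an assumption about shape, since the real mixins delegate to the JVM implementations):

```python
# Interface sketch (assumption): the doctests rely on nothing more than these
# two entry points provided by the pyspark.ml.util mixins.
class MLWritable(object):
    """Mixin for params/models that can be saved to a path."""
    def save(self, path):
        raise NotImplementedError()


class MLReadable(object):
    """Mixin for params/models that can be reloaded from a saved path."""
    @classmethod
    def load(cls, path):
        raise NotImplementedError()
```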
@@ -589,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels):
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
-                 HasRawPredictionCol):
+                 HasRawPredictionCol, MLWritable, MLReadable):
     """
     Naive Bayes Classifiers.
     It supports both Multinomial and Bernoulli NB. Multinomial NB
@@ -623,6 +635,18 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
     >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
     >>> model.transform(test1).head().prediction
     1.0
+    >>> nb_path = temp_path + "/nb"
+    >>> nb.save(nb_path)
+    >>> nb2 = NaiveBayes.load(nb_path)
+    >>> nb2.getSmoothing()
+    1.0
+    >>> model_path = temp_path + "/nb_model"
+    >>> model.save(model_path)
+    >>> model2 = NaiveBayesModel.load(model_path)
+    >>> model.pi == model2.pi
+    True
+    >>> model.theta == model2.theta
+    True
 
     .. versionadded:: 1.5.0
     """
@@ -696,7 +720,7 @@ def getModelType(self):
         return self.getOrDefault(self.modelType)
 
 
-class NaiveBayesModel(JavaModel):
+class NaiveBayesModel(JavaModel, MLWritable, MLReadable):
     """
     Model fitted by NaiveBayes.
 
@@ -853,17 +877,27 @@ def weights(self):
 
 if __name__ == "__main__":
     import doctest
+    import pyspark.ml.classification
     from pyspark.context import SparkContext
     from pyspark.sql import SQLContext
-    globs = globals().copy()
+    globs = pyspark.ml.classification.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     sc = SparkContext("local[2]", "ml.classification tests")
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(
-        globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    import tempfile
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        sc.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
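Two details of the harness change are easy to miss: copying `pyspark.ml.classification.__dict__` instead of `globals()` makes the doctests run against the importable module-level classes rather than copies bound in `__main__`, which matters once objects are serialized and reloaded, and the `try`/`finally` guarantees the scratch directory is removed even when a doctest fails. A generic version of that pattern, as a hypothetical helper rather than anything in this patch, could look like:

```python
# Hypothetical helper (not part of this patch): run a module's doctests with a
# scratch directory injected as `temp_path`, cleaning up unconditionally.
import doctest
import shutil
import tempfile


def run_doctests_with_temp_path(module, extra_globs=None):
    globs = module.__dict__.copy()   # module namespace, not __main__'s globals()
    globs.update(extra_globs or {})
    globs['temp_path'] = tempfile.mkdtemp()
    try:
        return doctest.testmod(module, globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        shutil.rmtree(globs['temp_path'], ignore_errors=True)
```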