Skip to content

Commit a95a4af

Browse files
committed
[SPARK-23120][PYSPARK][ML] Add basic PMML export support to PySpark
## What changes were proposed in this pull request?

Adds basic PMML export support for Spark ML stages to PySpark, as was previously done in Scala. Includes LinearRegressionModel as the first stage to implement it.

## How was this patch tested?

Doctest; the main testing work for this is on the Scala side. (TODO holden: add the unit test once I finish locally.)

Author: Holden Karau <[email protected]>

Closes #21172 from holdenk/SPARK-23120-add-pmml-export-support-to-pyspark.
1 parent 524827f commit a95a4af

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

python/pyspark/ml/regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
9595
True
9696
>>> model.numFeatures
9797
1
98+
>>> model.write().format("pmml").save(model_path + "_2")
9899
99100
.. versionadded:: 1.4.0
100101
"""
@@ -161,7 +162,7 @@ def getEpsilon(self):
161162
return self.getOrDefault(self.epsilon)
162163

163164

164-
class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
165+
class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable):
165166
"""
166167
Model fitted by :class:`LinearRegression`.
167168

python/pyspark/ml/tests.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,23 @@ def test_linear_regression(self):
13621362
except OSError:
13631363
pass
13641364

1365+
def test_linear_regression_pmml_basic(self):
    """Check that ``format("pmml")`` is respected when saving a model.

    Most of the validation is done on the Scala side; here we only check
    that the writer produced text rather than parquet, i.e. that the
    format flag was respected.
    """
    # Imported locally so this test does not rely on a module-level import.
    from shutil import rmtree

    df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
                                     (0.0, 2.0, Vectors.sparse(1, [], []))],
                                    ["label", "weight", "features"])
    lr = LinearRegression(maxIter=1)
    model = lr.fit(df)
    path = tempfile.mkdtemp()
    try:
        lr_path = path + "/lr-pmml"
        model.write().format("pmml").save(lr_path)
        # The export is read back as plain text lines and joined so the
        # assertions can search the whole document at once.
        pmml_text_list = self.sc.textFile(lr_path).collect()
        pmml_text = "\n".join(pmml_text_list)
        self.assertIn("Apache Spark", pmml_text)
        self.assertIn("PMML", pmml_text)
    finally:
        # Always remove the scratch directory, even when an assertion fails;
        # cleanup problems are ignored so they cannot mask the test result.
        rmtree(path, ignore_errors=True)
1381+
13651382
def test_logistic_regression(self):
13661383
lr = LogisticRegression(maxIter=1)
13671384
path = tempfile.mkdtemp()

python/pyspark/ml/util.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,23 @@ def overwrite(self):
148148
return self
149149

150150

151+
@inherit_doc
class GeneralMLWriter(MLWriter):
    """
    Writer for ML instances whose export format can be chosen by the caller.

    .. versionadded:: 2.4.0
    """

    def format(self, source):
        """
        Select the format used for the ML export.

        :param source: name of the export format, e.g. "pmml" or "internal", or the
                       fully qualified class name for export.
        :return: this writer, so that calls can be chained.
        """
        # The choice is only recorded here, on the ``source`` attribute.
        self.source = source
        return self
166+
167+
151168
@inherit_doc
152169
class JavaMLWriter(MLWriter):
153170
"""
@@ -192,6 +209,24 @@ def session(self, sparkSession):
192209
return self
193210

194211

212+
@inherit_doc
class GeneralJavaMLWriter(JavaMLWriter):
    """
    (Private) Counterpart of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types.

    Extends :py:class:`JavaMLWriter` and hands the chosen export format to the
    wrapped ``_jwrite`` object instead of storing it on the Python side.
    """

    def __init__(self, instance):
        # Construction is delegated unchanged to the parent writer.
        super(GeneralJavaMLWriter, self).__init__(instance)

    def format(self, source):
        """
        Select the format used for the ML export.

        :param source: name of the export format, e.g. "pmml" or "internal", or the
                       fully qualified class name for export.
        :return: this writer, so that calls can be chained.
        """
        # NOTE(review): ``_jwrite`` is presumably set up by JavaMLWriter.__init__,
        # which is outside this view -- confirm.
        self._jwrite.format(source)
        return self
228+
229+
195230
@inherit_doc
196231
class MLWritable(object):
197232
"""
@@ -220,6 +255,17 @@ def write(self):
220255
return JavaMLWriter(self)
221256

222257

258+
@inherit_doc
class GeneralJavaMLWritable(JavaMLWritable):
    """
    (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`.
    """

    def write(self):
        """Returns a :py:class:`GeneralJavaMLWriter` instance for this ML instance."""
        return GeneralJavaMLWriter(self)
267+
268+
223269
@inherit_doc
224270
class MLReader(BaseReadWrite):
225271
"""

0 commit comments

Comments
 (0)