156 changes: 151 additions & 5 deletions python/pyspark/ml/pipeline.py
@@ -16,14 +16,15 @@
#

import sys
import os

if sys.version > '3':
basestring = str

from pyspark import since, keyword_only, SparkContext
from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.util import *
Contributor: can we do import pyspark.ml.util as mlutil?

Member (@jkbradley, Aug 11, 2017): I'm OK either way, though mlutil would be cleaner

from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import inherit_doc

@@ -130,13 +131,16 @@ def copy(self, extra=None):
@since("2.0.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages())
if allStagesAreJava:
return JavaMLWriter(self)
return PipelineWriter(self)

@classmethod
@since("2.0.0")
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
return PipelineReader(cls)

@classmethod
def _from_java(cls, java_stage):
@@ -171,6 +175,76 @@ def _to_java(self):
return _java_obj


@inherit_doc
class PipelineWriter(MLWriter):
"""
(Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
"""

def __init__(self, instance):
super(PipelineWriter, self).__init__()
self.instance = instance

def saveImpl(self, path):
stages = self.instance.getStages()
PipelineSharedReadWrite.validateStages(stages)
PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)


@inherit_doc
class PipelineReader(MLReader):
"""
(Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
"""

def __init__(self, cls):
super(PipelineReader, self).__init__()
self.cls = cls

def load(self, path):
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
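# Pipelines whose metadata lacks a 'language' marker (or whose marker is not
# 'Python') were written by the Java/Scala writer, so fall back to JavaMLReader.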
if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
return JavaMLReader(self.cls).load(path)
else:
uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
return Pipeline(stages=stages)._resetUid(uid)


@inherit_doc
class PipelineModelWriter(MLWriter):
"""
(Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
"""

def __init__(self, instance):
super(PipelineModelWriter, self).__init__()
self.instance = instance

def saveImpl(self, path):
stages = self.instance.stages
PipelineSharedReadWrite.validateStages(stages)
PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)


@inherit_doc
class PipelineModelReader(MLReader):
"""
(Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
"""

def __init__(self, cls):
super(PipelineModelReader, self).__init__()
self.cls = cls

def load(self, path):
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
return JavaMLReader(self.cls).load(path)
else:
uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
return PipelineModel(stages=stages)._resetUid(uid)


@inherit_doc
class PipelineModel(Model, MLReadable, MLWritable):
"""
@@ -204,13 +278,16 @@ def copy(self, extra=None):
@since("2.0.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages)
if allStagesAreJava:
return JavaMLWriter(self)
return PipelineModelWriter(self)

@classmethod
@since("2.0.0")
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
return PipelineModelReader(cls)

@classmethod
def _from_java(cls, java_stage):
@@ -242,3 +319,72 @@ def _to_java(self):
JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)

return _java_obj


@inherit_doc
class PipelineSharedReadWrite(object):
"""
.. note:: DeveloperApi

Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
:py:class:`Pipeline` and :py:class:`PipelineModel`

.. versionadded:: 2.3.0
"""

@staticmethod
def checkStagesForJava(stages):
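"""
Check whether all stages are JavaMLWritable, i.e. backed by Java implementations.
"""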
return all(isinstance(stage, JavaMLWritable) for stage in stages)

@staticmethod
def validateStages(stages):
"""
Check that all stages are Writable
"""
for stage in stages:
if not isinstance(stage, MLWritable):
raise ValueError("Pipeline write will fail on this pipeline " +
"because stage %s of type %s is not MLWritable",
stage.uid, type(stage))

@staticmethod
def saveImpl(instance, stages, sc, path):
"""
Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
- save metadata to path/metadata
- save stages to stages/IDX_UID
"""
stageUids = [stage.uid for stage in stages]
jsonParams = {'stageUids': stageUids, 'language': 'Python'}
DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
stagesDir = os.path.join(path, "stages")
for index, stage in enumerate(stages):
stage.write().save(PipelineSharedReadWrite
.getStagePath(stage.uid, index, len(stages), stagesDir))

@staticmethod
def load(metadata, sc, path):
"""
Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`

:return: (UID, list of stages)
"""
stagesDir = os.path.join(path, "stages")
stageUids = metadata['paramMap']['stageUids']
stages = []
for index, stageUid in enumerate(stageUids):
stagePath = \
PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
stages.append(stage)
return (metadata['uid'], stages)

@staticmethod
def getStagePath(stageUid, stageIdx, numStages, stagesDir):
Contributor: stageIdx isn't used by this method, is that intentional?

Contributor Author: It should be used. Fixed this. Thanks!

"""
Get path for saving the given stage.
"""
stageIdxDigits = len(str(numStages))
stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
stagePath = os.path.join(stagesDir, stageDir)
return stagePath
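
For reference (not part of the patch), a small standalone sketch of the directory naming that getStagePath produces; the stage uid and paths below are made up:

import os

def get_stage_path(stage_uid, stage_idx, num_stages, stages_dir):
    # Same scheme as PipelineSharedReadWrite.getStagePath: zero-pad the stage index
    # to the width of the stage count so the directories sort in pipeline order.
    stage_idx_digits = len(str(num_stages))
    return os.path.join(stages_dir, str(stage_idx).zfill(stage_idx_digits) + "_" + stage_uid)

print(get_stage_path("Binarizer_9f1c", 3, 12, "/tmp/pipeline/stages"))
# -> /tmp/pipeline/stages/03_Binarizer_9f1c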
35 changes: 32 additions & 3 deletions python/pyspark/ml/tests.py
@@ -123,7 +123,7 @@ def _transform(self, dataset):
return dataset


class MockUnaryTransformer(UnaryTransformer):
class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):

shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
"data in a DataFrame",
@@ -150,7 +150,7 @@ def outputDataType(self):
def validateInputType(self, inputType):
if inputType != DoubleType():
raise TypeError("Bad input type: {}. ".format(inputType) +
"Requires Integer.")
"Requires Double.")


class MockEstimator(Estimator, HasFake):
@@ -1063,7 +1063,7 @@ def _compare_pipelines(self, m1, m2):
"""
self.assertEqual(m1.uid, m2.uid)
self.assertEqual(type(m1), type(m2))
if isinstance(m1, JavaParams):
if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self._compare_params(m1, m2, p)
@@ -1142,6 +1142,35 @@ def test_nested_pipeline_persistence(self):
except OSError:
pass

def test_python_transformer_pipeline_persistence(self):
"""
Pipeline[MockUnaryTransformer, Binarizer]
"""
temp_path = tempfile.mkdtemp()

try:
df = self.spark.range(0, 10).toDF('input')
tf = MockUnaryTransformer(shiftVal=2)\
.setInputCol("input").setOutputCol("shiftedInput")
tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
pl = Pipeline(stages=[tf, tf2])
model = pl.fit(df)

pipeline_path = temp_path + "/pipeline"
pl.save(pipeline_path)
loaded_pipeline = Pipeline.load(pipeline_path)
self._compare_pipelines(pl, loaded_pipeline)

model_path = temp_path + "/pipeline-model"
model.save(model_path)
loaded_model = PipelineModel.load(model_path)
self._compare_pipelines(model, loaded_model)
finally:
try:
rmtree(temp_path)
Contributor: Why do we need this in a try block? I worry about silencing errors in tests because it's a good way to miss issues.

Contributor Author: This is the same pattern that exists in all other tests so I just followed it for this one.

Member: As I recall, it was because we didn't want tests to fail because of cleanup failing. I forget if/when cleanup failures were causing a problem...

except OSError:
pass

def test_onevsrest(self):
temp_path = tempfile.mkdtemp()
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
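
For reference, a minimal standalone sketch of the scenario the new test_python_transformer_pipeline_persistence exercises: a Python-only Transformer made persistable via DefaultParamsReadable/DefaultParamsWritable, combined with a Java-backed Binarizer, then saved and reloaded. The AddOne class, column names, and paths are illustrative and not part of the patch; a custom class like this should live in an importable module (not the __main__ script) so DefaultParamsReader can reconstruct it on load.

from pyspark import keyword_only
from pyspark.ml import Pipeline, PipelineModel, Transformer
from pyspark.ml.feature import Binarizer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import SparkSession
from pyspark.sql.functions import col


class AddOne(Transformer, HasInputCol, HasOutputCol,
             DefaultParamsReadable, DefaultParamsWritable):
    """Python-only stage: writes inputCol + 1 into outputCol."""

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(AddOne, self).__init__()
        self._set(**self._input_kwargs)

    def _transform(self, dataset):
        return dataset.withColumn(self.getOutputCol(), col(self.getInputCol()) + 1.0)


spark = SparkSession.builder.getOrCreate()
df = spark.range(10).withColumn("x", col("id").cast("double"))

pipeline = Pipeline(stages=[
    AddOne(inputCol="x", outputCol="shifted"),                            # Python-only stage
    Binarizer(threshold=5.0, inputCol="shifted", outputCol="binarized"),  # Java-backed stage
])
model = pipeline.fit(df)

# Not every stage is JavaMLWritable, so write() returns PipelineWriter / PipelineModelWriter,
# which record language='Python' in the metadata and save each stage under stages/IDX_UID.
pipeline.save("/tmp/python_stage_pipeline")
model.save("/tmp/python_stage_pipeline_model")

loaded = PipelineModel.load("/tmp/python_stage_pipeline_model")
loaded.transform(df).show()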