Skip to content
158 changes: 153 additions & 5 deletions python/pyspark/ml/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
#

import sys
import os

if sys.version > '3':
basestring = str

from pyspark import since, keyword_only, SparkContext
from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.util import *

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use `import pyspark.ml.util as mlutil` instead of the wildcard import?

@jkbradley jkbradley Aug 11, 2017

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm OK either way, though mlutil would be cleaner

from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import inherit_doc

Expand Down Expand Up @@ -130,13 +131,16 @@ def copy(self, extra=None):
@since("2.0.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
allStagesAreJava = SharedReadWrite.checkStagesForJava(self.getStages())
if allStagesAreJava:
return JavaMLWriter(self)
return PipelineWriter(self)

@classmethod
@since("2.0.0")
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
return PipelineReader(cls)

@classmethod
def _from_java(cls, java_stage):
Expand Down Expand Up @@ -171,6 +175,76 @@ def _to_java(self):
return _java_obj


@inherit_doc
class PipelineWriter(MLWriter):
    """
    (Private) :py:class:`MLWriter` specialization that persists :py:class:`Pipeline` instances
    """

    def __init__(self, instance):
        super(PipelineWriter, self).__init__()
        # The Pipeline whose metadata and stages are written by saveImpl.
        self.instance = instance

    def saveImpl(self, path):
        """
        Validate the pipeline's stages, then save its metadata and stages under ``path``.

        :param path: destination path handed to :py:class:`SharedReadWrite`
        """
        pipeline = self.instance
        pipelineStages = pipeline.getStages()
        # Validate first so nothing is written when a stage is not writable.
        SharedReadWrite.validateStages(pipelineStages)
        SharedReadWrite.saveImpl(pipeline, pipelineStages, self.sc, path)


@inherit_doc
class PipelineReader(MLReader):
    """
    (Private) :py:class:`MLReader` specialization that loads :py:class:`Pipeline` instances
    """

    def __init__(self, cls):
        super(PipelineReader, self).__init__()
        # Class passed on to JavaMLReader when the saved data has no Python marker.
        self.cls = cls

    def load(self, path):
        """
        Load a :py:class:`Pipeline` from ``path``.

        Metadata carrying the 'savedAsPython' key is loaded through
        :py:class:`SharedReadWrite`; anything else is delegated to :py:class:`JavaMLReader`.

        :param path: location the pipeline was saved to
        :return: the loaded pipeline
        """
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if 'savedAsPython' in metadata['paramMap']:
            uid, stages = SharedReadWrite.load(metadata, self.sc, path)
            # Restore the saved UID on the freshly constructed pipeline.
            return Pipeline(stages=stages)._resetUid(uid)
        return JavaMLReader(self.cls).load(path)


@inherit_doc
class PipelineModelWriter(MLWriter):
    """
    (Private) :py:class:`MLWriter` specialization that persists :py:class:`PipelineModel`
    instances
    """

    def __init__(self, instance):
        super(PipelineModelWriter, self).__init__()
        # The PipelineModel whose metadata and stages are written by saveImpl.
        self.instance = instance

    def saveImpl(self, path):
        """
        Validate the model's stages, then save its metadata and stages under ``path``.

        :param path: destination path handed to :py:class:`SharedReadWrite`
        """
        model = self.instance
        modelStages = model.stages
        # Validate first so nothing is written when a stage is not writable.
        SharedReadWrite.validateStages(modelStages)
        SharedReadWrite.saveImpl(model, modelStages, self.sc, path)


@inherit_doc
class PipelineModelReader(MLReader):
    """
    (Private) :py:class:`MLReader` specialization that loads :py:class:`PipelineModel`
    instances
    """

    def __init__(self, cls):
        super(PipelineModelReader, self).__init__()
        # Class passed on to JavaMLReader when the saved data has no Python marker.
        self.cls = cls

    def load(self, path):
        """
        Load a :py:class:`PipelineModel` from ``path``.

        Metadata carrying the 'savedAsPython' key is loaded through
        :py:class:`SharedReadWrite`; anything else is delegated to :py:class:`JavaMLReader`.

        :param path: location the model was saved to
        :return: the loaded pipeline model
        """
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if 'savedAsPython' in metadata['paramMap']:
            uid, stages = SharedReadWrite.load(metadata, self.sc, path)
            # Restore the saved UID on the freshly constructed model.
            return PipelineModel(stages=stages)._resetUid(uid)
        return JavaMLReader(self.cls).load(path)


@inherit_doc
class PipelineModel(Model, MLReadable, MLWritable):
"""
Expand Down Expand Up @@ -204,13 +278,16 @@ def copy(self, extra=None):
@since("2.0.0")
def write(self):
"""Returns an MLWriter instance for this ML instance."""
return JavaMLWriter(self)
allStagesAreJava = SharedReadWrite.checkStagesForJava(self.stages)
if allStagesAreJava:
return JavaMLWriter(self)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see similar logic in two places; can you move it into a utility function?

return PipelineModelWriter(self)

@classmethod
@since("2.0.0")
def read(cls):
"""Returns an MLReader instance for this class."""
return JavaMLReader(cls)
return PipelineModelReader(cls)

@classmethod
def _from_java(cls, java_stage):
Expand Down Expand Up @@ -242,3 +319,74 @@ def _to_java(self):
JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)

return _java_obj


# NOTE(review): reviewers suggested a name containing "Pipeline" (e.g. PipelineSharedReadWrite)
# and marking this class as private or DeveloperApi. The name is kept unchanged here because
# the pipeline readers and writers in this file call it as SharedReadWrite.
@inherit_doc
class SharedReadWrite():
    """
    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
    :py:class:`Pipeline` and :py:class:`PipelineModel`

    .. versionadded:: 2.3.0
    """

    @staticmethod
    def checkStagesForJava(stages):
        """
        Check whether every stage is a :py:class:`JavaMLWritable`.

        :param stages: iterable of pipeline stages
        :return: True if all stages are JavaMLWritable (also True when there are no stages),
                 False otherwise
        """
        return all(isinstance(stage, JavaMLWritable) for stage in stages)

    @staticmethod
    def validateStages(stages):
        """
        Check that all stages are Writable

        :param stages: iterable of pipeline stages
        :raises ValueError: if any stage is not an instance of :py:class:`MLWritable`
        """
        for stage in stages:
            if not isinstance(stage, MLWritable):
                # Interpolate with ``%``: passing the values as extra ValueError arguments
                # would leave the %s placeholders unfilled in the message.
                raise ValueError("Pipeline write will fail on this pipeline "
                                 "because stage %s of type %s is not MLWritable"
                                 % (stage.uid, type(stage)))

    @staticmethod
    def saveImpl(instance, stages, sc, path):
        """
        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
        - save metadata to path/metadata
        - save stages to stages/IDX_UID

        :param instance: the Pipeline or PipelineModel being saved
        :param stages: the stages of ``instance``, in order
        :param sc: SparkContext forwarded to :py:class:`DefaultParamsWriter`
        :param path: root path to save under
        """
        stageUids = [stage.uid for stage in stages]
        # The readers use the 'savedAsPython' key to tell Python-saved pipelines apart.
        # NOTE(review): a more general key such as 'language': 'Python' was suggested for
        # future extensibility (e.g. R); changing it would also require updating the readers.
        jsonParams = {'stageUids': stageUids, 'savedAsPython': True}
        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
        # NOTE(review): os.path.join follows the local OS path format, which reviewers flagged
        # as a possible risk; no better PySpark alternative was identified in review.
        stagesDir = os.path.join(path, "stages")
        numStages = len(stages)
        for index, stage in enumerate(stages):
            stagePath = SharedReadWrite.getStagePath(stage.uid, index, numStages, stagesDir)
            stage.write().save(stagePath)

    @staticmethod
    def load(metadata, sc, path):
        """
        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`

        :param metadata: loaded metadata; must contain 'uid' and paramMap['stageUids']
        :param sc: SparkContext forwarded to :py:class:`DefaultParamsReader`
        :param path: root path the pipeline was saved under
        :return: (UID, list of stages)
        """
        stagesDir = os.path.join(path, "stages")
        stageUids = metadata['paramMap']['stageUids']
        numStages = len(stageUids)
        stages = []
        for index, stageUid in enumerate(stageUids):
            # Rebuild the same per-stage directory name that saveImpl wrote to.
            stagePath = SharedReadWrite.getStagePath(stageUid, index, numStages, stagesDir)
            stages.append(DefaultParamsReader.loadParamsInstance(stagePath, sc))
        return (metadata['uid'], stages)

    @staticmethod
    def getStagePath(stageUid, stageIdx, numStages, stagesDir):
        """
        Get path for saving the given stage.

        The directory name is the stage index, zero-padded to the number of digits in
        ``numStages``, followed by "_" and the stage UID.

        :param stageUid: UID of the stage
        :param stageIdx: position of the stage within the pipeline
        :param numStages: total number of stages, used only to size the zero padding
        :param stagesDir: directory that holds all stage directories
        :return: path of the directory for this stage
        """
        stageIdxDigits = len(str(numStages))
        stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
        return os.path.join(stagesDir, stageDir)
35 changes: 32 additions & 3 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _transform(self, dataset):
return dataset


class MockUnaryTransformer(UnaryTransformer):
class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):

shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
"data in a DataFrame",
Expand All @@ -150,7 +150,7 @@ def outputDataType(self):
def validateInputType(self, inputType):
if inputType != DoubleType():
raise TypeError("Bad input type: {}. ".format(inputType) +
"Requires Integer.")
"Requires Double.")


class MockEstimator(Estimator, HasFake):
Expand Down Expand Up @@ -1063,7 +1063,7 @@ def _compare_pipelines(self, m1, m2):
"""
self.assertEqual(m1.uid, m2.uid)
self.assertEqual(type(m1), type(m2))
if isinstance(m1, JavaParams):
if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self._compare_params(m1, m2, p)
Expand Down Expand Up @@ -1142,6 +1142,35 @@ def test_nested_pipeline_persistence(self):
except OSError:
pass

def test_python_transformer_pipeline_persistence(self):
    """
    Pipeline[MockUnaryTransformer, Binarizer]

    Saves and reloads both the pipeline and its fitted model, comparing each reloaded
    object with the original.
    """
    temp_path = tempfile.mkdtemp()

    try:
        df = self.spark.range(0, 10).toDF('input')
        shifter = MockUnaryTransformer(shiftVal=2)\
            .setInputCol("input").setOutputCol("shiftedInput")
        binarizer = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
        pl = Pipeline(stages=[shifter, binarizer])
        model = pl.fit(df)

        # Same save -> load -> compare round trip, first for the pipeline, then the model.
        round_trips = [
            (pl, Pipeline, temp_path + "/pipeline"),
            (model, PipelineModel, temp_path + "/pipeline-model"),
        ]
        for original, loader, save_path in round_trips:
            original.save(save_path)
            reloaded = loader.load(save_path)
            self._compare_pipelines(original, reloaded)
    finally:
        # Best-effort cleanup, following the pattern of the other persistence tests;
        # per review discussion this is presumably so a failed cleanup does not fail the test.
        try:
            rmtree(temp_path)
        except OSError:
            pass

def test_onevsrest(self):
temp_path = tempfile.mkdtemp()
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
Expand Down