-
Notifications
You must be signed in to change notification settings - Fork 29.2k
[Spark-17025][ML][Python] Persistence for Pipelines with Python-only Stages #18888
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
840a193
85a98d6
0eb0494
ba4402c
22ebe3e
6a094f0
4d2caf8
cdcd1cc
cf1a08d
18c902c
2b63eea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,14 +16,15 @@ | |
| # | ||
|
|
||
| import sys | ||
| import os | ||
|
|
||
| if sys.version > '3': | ||
| basestring = str | ||
|
|
||
| from pyspark import since, keyword_only, SparkContext | ||
| from pyspark.ml.base import Estimator, Model, Transformer | ||
| from pyspark.ml.param import Param, Params | ||
| from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable | ||
| from pyspark.ml.util import * | ||
| from pyspark.ml.wrapper import JavaParams | ||
| from pyspark.ml.common import inherit_doc | ||
|
|
||
|
|
@@ -130,13 +131,16 @@ def copy(self, extra=None): | |
| @since("2.0.0") | ||
| def write(self): | ||
| """Returns an MLWriter instance for this ML instance.""" | ||
| return JavaMLWriter(self) | ||
| allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages()) | ||
| if allStagesAreJava: | ||
| return JavaMLWriter(self) | ||
| return PipelineWriter(self) | ||
|
|
||
| @classmethod | ||
| @since("2.0.0") | ||
| def read(cls): | ||
| """Returns an MLReader instance for this class.""" | ||
| return JavaMLReader(cls) | ||
| return PipelineReader(cls) | ||
|
|
||
| @classmethod | ||
| def _from_java(cls, java_stage): | ||
|
|
@@ -171,6 +175,76 @@ def _to_java(self): | |
| return _java_obj | ||
|
|
||
|
|
||
@inherit_doc
class PipelineWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
    """

    def __init__(self, instance):
        super(PipelineWriter, self).__init__()
        # The Pipeline instance this writer persists.
        self.instance = instance

    def saveImpl(self, path):
        # Refuse to write pipelines containing non-writable stages, then
        # delegate the shared metadata/stage layout to PipelineSharedReadWrite.
        pipeline_stages = self.instance.getStages()
        PipelineSharedReadWrite.validateStages(pipeline_stages)
        PipelineSharedReadWrite.saveImpl(self.instance, pipeline_stages, self.sc, path)
|
|
||
|
|
||
@inherit_doc
class PipelineReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
    """

    def __init__(self, cls):
        super(PipelineReader, self).__init__()
        # The Pipeline class object used for JVM-side loading.
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        saved_language = metadata['paramMap'].get('language')
        if saved_language == 'Python':
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
            return Pipeline(stages=stages)._resetUid(uid)
        # Anything not tagged as Python was persisted by the JVM writer.
        return JavaMLReader(self.cls).load(path)
|
|
||
|
|
||
@inherit_doc
class PipelineModelWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
    """

    def __init__(self, instance):
        super(PipelineModelWriter, self).__init__()
        # The fitted PipelineModel instance this writer persists.
        self.instance = instance

    def saveImpl(self, path):
        # Validate that every fitted stage is writable before touching disk,
        # then reuse the shared Pipeline/PipelineModel save layout.
        fitted_stages = self.instance.stages
        PipelineSharedReadWrite.validateStages(fitted_stages)
        PipelineSharedReadWrite.saveImpl(self.instance, fitted_stages, self.sc, path)
|
|
||
|
|
||
@inherit_doc
class PipelineModelReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
    """

    def __init__(self, cls):
        super(PipelineModelReader, self).__init__()
        # The PipelineModel class object used for JVM-side loading.
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        saved_language = metadata['paramMap'].get('language')
        if saved_language == 'Python':
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
            return PipelineModel(stages=stages)._resetUid(uid)
        # Anything not tagged as Python was persisted by the JVM writer.
        return JavaMLReader(self.cls).load(path)
|
|
||
|
|
||
| @inherit_doc | ||
| class PipelineModel(Model, MLReadable, MLWritable): | ||
| """ | ||
|
|
@@ -204,13 +278,16 @@ def copy(self, extra=None): | |
| @since("2.0.0") | ||
| def write(self): | ||
| """Returns an MLWriter instance for this ML instance.""" | ||
| return JavaMLWriter(self) | ||
| allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages) | ||
| if allStagesAreJava: | ||
| return JavaMLWriter(self) | ||
| return PipelineModelWriter(self) | ||
|
|
||
| @classmethod | ||
| @since("2.0.0") | ||
| def read(cls): | ||
| """Returns an MLReader instance for this class.""" | ||
| return JavaMLReader(cls) | ||
| return PipelineModelReader(cls) | ||
|
|
||
| @classmethod | ||
| def _from_java(cls, java_stage): | ||
|
|
@@ -242,3 +319,72 @@ def _to_java(self): | |
| JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) | ||
|
|
||
| return _java_obj | ||
|
|
||
|
|
||
@inherit_doc
class PipelineSharedReadWrite(object):
    """
    .. note:: DeveloperApi

    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
    :py:class:`Pipeline` and :py:class:`PipelineModel`

    .. versionadded:: 2.3.0
    """

    @staticmethod
    def checkStagesForJava(stages):
        """Returns True iff every stage is a :py:class:`JavaMLWritable`, i.e.
        the whole pipeline can be persisted by the JVM writer."""
        return all(isinstance(stage, JavaMLWritable) for stage in stages)

    @staticmethod
    def validateStages(stages):
        """
        Check that all stages are Writable

        :raises ValueError: if any stage is not an :py:class:`MLWritable`
        """
        for stage in stages:
            if not isinstance(stage, MLWritable):
                # Interpolate the message eagerly with %: previously the
                # arguments were passed as extra ValueError positional args
                # (a comma instead of %), so the message was never formatted.
                raise ValueError("Pipeline write will fail on this pipeline "
                                 "because stage %s of type %s is not MLWritable"
                                 % (stage.uid, type(stage)))

    @staticmethod
    def saveImpl(instance, stages, sc, path):
        """
        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
        - save metadata to path/metadata
        - save stages to stages/IDX_UID

        The 'language': 'Python' marker in the metadata is what the readers
        use to decide between the Python and JVM load paths.
        """
        stageUids = [stage.uid for stage in stages]
        jsonParams = {'stageUids': stageUids, 'language': 'Python'}
        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
        stagesDir = os.path.join(path, "stages")
        for index, stage in enumerate(stages):
            stage.write().save(PipelineSharedReadWrite
                               .getStagePath(stage.uid, index, len(stages), stagesDir))

    @staticmethod
    def load(metadata, sc, path):
        """
        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`

        :return: (UID, list of stages)
        """
        stagesDir = os.path.join(path, "stages")
        stageUids = metadata['paramMap']['stageUids']
        stages = []
        for index, stageUid in enumerate(stageUids):
            stagePath = \
                PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
            stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
            stages.append(stage)
        return (metadata['uid'], stages)

    @staticmethod
    def getStagePath(stageUid, stageIdx, numStages, stagesDir):
        """
        Get path for saving the given stage.
        """
        # Zero-pad the index so stage directories sort lexicographically in
        # pipeline order (e.g. "00_uid", "01_uid", ... when numStages >= 10).
        stageIdxDigits = len(str(numStages))
        stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
        stagePath = os.path.join(stagesDir, stageDir)
        return stagePath
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -123,7 +123,7 @@ def _transform(self, dataset): | |
| return dataset | ||
|
|
||
|
|
||
| class MockUnaryTransformer(UnaryTransformer): | ||
| class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable): | ||
|
|
||
| shift = Param(Params._dummy(), "shift", "The amount by which to shift " + | ||
| "data in a DataFrame", | ||
|
|
@@ -150,7 +150,7 @@ def outputDataType(self): | |
| def validateInputType(self, inputType): | ||
| if inputType != DoubleType(): | ||
| raise TypeError("Bad input type: {}. ".format(inputType) + | ||
| "Requires Integer.") | ||
| "Requires Double.") | ||
|
|
||
|
|
||
| class MockEstimator(Estimator, HasFake): | ||
|
|
@@ -1063,7 +1063,7 @@ def _compare_pipelines(self, m1, m2): | |
| """ | ||
| self.assertEqual(m1.uid, m2.uid) | ||
| self.assertEqual(type(m1), type(m2)) | ||
| if isinstance(m1, JavaParams): | ||
| if isinstance(m1, JavaParams) or isinstance(m1, Transformer): | ||
| self.assertEqual(len(m1.params), len(m2.params)) | ||
| for p in m1.params: | ||
| self._compare_params(m1, m2, p) | ||
|
|
@@ -1142,6 +1142,35 @@ def test_nested_pipeline_persistence(self): | |
| except OSError: | ||
| pass | ||
|
|
||
| def test_python_transformer_pipeline_persistence(self): | ||
| """ | ||
| Pipeline[MockUnaryTransformer, Binarizer] | ||
| """ | ||
| temp_path = tempfile.mkdtemp() | ||
|
|
||
| try: | ||
| df = self.spark.range(0, 10).toDF('input') | ||
| tf = MockUnaryTransformer(shiftVal=2)\ | ||
| .setInputCol("input").setOutputCol("shiftedInput") | ||
| tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") | ||
| pl = Pipeline(stages=[tf, tf2]) | ||
| model = pl.fit(df) | ||
|
|
||
| pipeline_path = temp_path + "/pipeline" | ||
| pl.save(pipeline_path) | ||
| loaded_pipeline = Pipeline.load(pipeline_path) | ||
| self._compare_pipelines(pl, loaded_pipeline) | ||
|
|
||
| model_path = temp_path + "/pipeline-model" | ||
| model.save(model_path) | ||
| loaded_model = PipelineModel.load(model_path) | ||
| self._compare_pipelines(model, loaded_model) | ||
| finally: | ||
| try: | ||
| rmtree(temp_path) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need this in a
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the same pattern that exists in all other tests so I just followed it for this one.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As I recall, it was because we didn't want tests to fail because of cleanup failing. I forget if/when cleanup failures were causing a problem... |
||
| except OSError: | ||
| pass | ||
|
|
||
| def test_onevsrest(self): | ||
| temp_path = tempfile.mkdtemp() | ||
| df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we do `import pyspark.ml.util as mlutil` instead of the wildcard import?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm OK either way, though mlutil would be cleaner