[Spark-17025][ML][Python] Persistence for Pipelines with Python-only Stages #18888

ajaysaini725 wants to merge 11 commits into apache:master
Conversation
@jkbradley @MrBago @WeichenXu123 Can you please review this?

Test build #80421 has finished for PR 18888 at commit
Test build #80426 has finished for PR 18888 at commit
Test build #80427 has finished for PR 18888 at commit
    if not isinstance(stage, JavaMLWritable):
        allStagesAreJava = False
    if allStagesAreJava:
        return JavaMLWriter(self)

I see similar logic twice; can you move it into a util function?
    stageUids = [stage.uid for stage in stages]
    jsonParams = {'stageUids': stageUids, 'savedAsPython': True}
    DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
    stagesDir = os.path.join(path, "stages")

Using os.path.join to generate the full path may be risky, because it depends on the local OS path format.

@jkbradley, what's the right way to handle paths in PySpark? Scala has org.apache.hadoop.fs.Path; is there something similar in PySpark?

This is as good as it gets, as far as I know.
Test build #80465 has finished for PR 18888 at commit
MrBago left a comment:

This looks good @ajaysaini725, mostly minor comments.

One more major concern I have is with __get_class in DefaultParamsReader. I know it's outside this PR, but it looks a little brittle on first pass. Are we testing this method with different module structures and notebooks?
    from pyspark.ml.base import Estimator, Model, Transformer
    from pyspark.ml.param import Param, Params
    from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
    from pyspark.ml.util import *

Can we do import pyspark.ml.util as mlutil instead?

I'm OK either way, though mlutil would be cleaner.
    stages = self.getStages()
    for stage in stages:
        if not isinstance(stage, JavaMLWritable):
            allStagesAreJava = False

How about allStagesAreJava = all(isinstance(stage, JavaMLWritable) for stage in self.getStages())?
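The suggested one-liner can be sketched with stand-in classes, since pyspark is not assumed to be available here; `JavaMLWritable` below is a placeholder for the real `pyspark.ml.util.JavaMLWritable` marker class.

```python
class JavaMLWritable:
    """Stand-in for pyspark.ml.util.JavaMLWritable (marks JVM-persistable stages)."""

class JavaStage(JavaMLWritable):
    pass

class PythonOnlyStage:
    pass

def all_stages_are_java(stages):
    # One pass that short-circuits on the first non-Java stage --
    # equivalent to the loop-and-flag version but more idiomatic.
    return all(isinstance(stage, JavaMLWritable) for stage in stages)

print(all_stages_are_java([JavaStage(), JavaStage()]))        # True
print(all_stages_are_java([JavaStage(), PythonOnlyStage()]))  # False
```

Note that `all()` on an empty list returns True, which matches the loop-and-flag version's behavior for a pipeline with no stages.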
    return (metadata['uid'], stages)

    @staticmethod
    def getStagePath(stageUid, stageIdx, numStages, stagesDir):

stageIdx isn't used by this method, is that intentional?

It should be used. Fixed this. Thanks!
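A fixed version that actually uses stageIdx might look like the sketch below: the index is zero-padded so stage directories sort in pipeline order, matching the "stages/IDX_UID" layout mentioned later in this PR. The exact padding scheme here is an assumption, not necessarily the merged implementation.

```python
import os

def getStagePath(stageUid, stageIdx, numStages, stagesDir):
    """Return the directory path for one stage, e.g. stages/02_SomeStage_uid."""
    # Pad the index to as many digits as the largest index needs,
    # so lexicographic directory order equals pipeline order.
    stageIdxDigits = len(str(numStages))
    stageDir = str(stageIdx).zfill(stageIdxDigits) + '_' + stageUid
    return os.path.join(stagesDir, stageDir)

print(getStagePath("LogisticRegression_abc123", 2, 12, "stages"))
# e.g. "stages/02_LogisticRegression_abc123" on POSIX systems
```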
    self._compare_pipelines(model, loaded_model)
finally:
    try:
        rmtree(temp_path)

Why do we need this in a try block? I worry about silencing errors in tests because it's a good way to miss issues.

This is the same pattern that exists in all other tests, so I just followed it for this one.

As I recall, it was because we didn't want tests to fail because of cleanup failing. I forget if/when cleanup failures were causing a problem...
Test build #80513 has finished for PR 18888 at commit
    """
    for stage in stages:
        if not isinstance(stage, MLWritable):
            raise ValueError("Pipeline write will fail on this pipeline " +

    @inherit_doc
    class SharedReadWrite():

Rename to something with "Pipeline", such as "PipelineSharedReadWrite".

Also, note that it should be either private or DeveloperApi.
    class SharedReadWrite():
        """
        Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
        :py:class:`Pipeline` and :py:class:`PipelineModel`

        - save stages to stages/IDX_UID
        """
        stageUids = [stage.uid for stage in stages]
        jsonParams = {'stageUids': stageUids, 'savedAsPython': True}

It just occurred to me: for future extensibility, it would make sense to change this to something like 'language': 'Python', since there may be something analogous for R or other languages in the future.
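The suggested metadata change could look like the sketch below: a 'language' field replacing the boolean 'savedAsPython', leaving room for R or other bindings later. The 'stageUids' field follows the diff in this PR; the surrounding structure and example uids are assumptions.

```python
import json

# Hypothetical stage uids for illustration only.
stageUids = ["VectorAssembler_abc", "LogisticRegression_def"]

# 'language': 'Python' instead of 'savedAsPython': True -- a reader can then
# branch on the language name rather than on a Python-specific boolean.
jsonParams = {'stageUids': stageUids, 'language': 'Python'}

print(json.dumps(jsonParams, sort_keys=True))
```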
    @staticmethod
    def checkStagesForJava(stages):
        allStagesAreJava = True

Copying @MrBago's comment here: how about allStagesAreJava = all(isinstance(stage, JavaMLWritable) for stage in self.getStages())?
Test build #80547 has finished for PR 18888 at commit

LGTM pending tests!

Test build #80548 has finished for PR 18888 at commit

@jkbradley Quick reminder to merge this since the tests have passed!
What changes were proposed in this pull request?
Implemented a Python-only persistence framework for pipelines containing stages that cannot be saved using Java.
How was this patch tested?
Created a custom Python-only UnaryTransformer, included it in a Pipeline, and saved/loaded the pipeline. The loaded pipeline was compared against the original using _compare_pipelines() in tests.py.
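The core dispatch this PR introduces can be sketched with stand-in classes (pyspark itself is not assumed here, and all class names below are placeholders for their pyspark counterparts): the pipeline's writer is JVM-backed only when every stage is Java-writable, and otherwise falls back to the new Python-only path.

```python
class JavaMLWritable:
    """Stand-in for pyspark.ml.util.JavaMLWritable."""

class JavaMLWriter:
    """Stand-in for the JVM-backed writer."""
    def __init__(self, instance):
        self.instance = instance

class PipelineWriter:
    """Stand-in for the Python-only pipeline writer added by this PR."""
    def __init__(self, instance):
        self.instance = instance

class Pipeline:
    def __init__(self, stages):
        self.stages = stages

    def write(self):
        # Round-trip through the JVM only if every stage supports it;
        # otherwise use pure-Python persistence.
        if all(isinstance(s, JavaMLWritable) for s in self.stages):
            return JavaMLWriter(self)
        return PipelineWriter(self)

class JavaStage(JavaMLWritable):
    pass

class PythonOnlyStage:
    pass

print(type(Pipeline([JavaStage()]).write()).__name__)                     # JavaMLWriter
print(type(Pipeline([JavaStage(), PythonOnlyStage()]).write()).__name__)  # PipelineWriter
```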