Replace sparkdl's ImageSchema with Spark2.3's version #85
Changes from 6 commits
adc132d
fd9a38f
537d1b1
5516b50
327ddc8
a5f2ff1
0ea3761
936838a
1ab34bf
c2e803b
a3f4d08
421c924
1a117b4
005fd61
10c182c
def1e0e
5ef9a6b
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,5 @@ | ||
| // You may use this file to add plugin dependencies for sbt. | ||
| resolvers += "Spark Packages repo" at "https://dl.bintray.com/spark-packages/maven/" | ||
|
|
||
| addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.5") | ||
|
|
||
| // scalacOptions in (Compile,doc) := Seq("-groups", "-implicits") | ||
|
|
||
| addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,221 @@ | ||
| # NOTE: This file is copied from Spark2.3 in order to be able to use this in allready released spark versions. | ||
| # TODO: remove this when Spark 2.3 is out! | ||
| # | ||
| # Licensed to the Apache Software Foundation (ASF) under one or more | ||
| # contributor license agreements. See the NOTICE file distributed with | ||
| # this work for additional information regarding copyright ownership. | ||
| # The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| # (the "License"); you may not use this file except in compliance with | ||
| # the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
|
Collaborator
Let's put TODO comments at the top of these copied ImageSchema files stating that they should be removed once Spark 2.3 is out, and that they are copied from Spark and should not be modified.
Contributor
Author
Yeah, that's a good idea. |
||
| """ | ||
| .. attribute:: ImageSchema | ||
|
|
||
| An attribute of this module that contains the instance of :class:`_ImageSchema`. | ||
|
|
||
| .. autoclass:: _ImageSchema | ||
| :members: | ||
| """ | ||
|
|
||
| import numpy as np | ||
| from pyspark import SparkContext | ||
| from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string | ||
| from pyspark.sql import DataFrame, SparkSession | ||
|
|
||
|
|
||
| class _ImageSchema(object): | ||
| """ | ||
| Internal class for `pyspark.ml.image.ImageSchema` attribute. Meant to be private and | ||
| not to be instantized. Use `pyspark.ml.image.ImageSchema` attribute to access the | ||
| APIs of this class. | ||
| """ | ||
|
|
||
| def __init__(self): | ||
| self._imageSchema = None | ||
| self._ocvTypes = None | ||
| self._imageFields = None | ||
| self._undefinedImageType = None | ||
|
|
||
| @property | ||
| def imageSchema(self): | ||
| """ | ||
| Returns the image schema. | ||
|
|
||
| :return: a :class:`StructType` with a single column of images | ||
| named "image" (nullable). | ||
|
|
||
| .. versionadded:: 2.3.0 | ||
| """ | ||
|
|
||
| if self._imageSchema is None: | ||
| ctx = SparkContext._active_spark_context | ||
| jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema() | ||
| self._imageSchema = _parse_datatype_json_string(jschema.json()) | ||
| return self._imageSchema | ||
|
|
||
| @property | ||
| def ocvTypes(self): | ||
| """ | ||
| Returns the OpenCV type mapping supported. | ||
|
|
||
| :return: a dictionary containing the OpenCV type mapping supported. | ||
|
|
||
| .. versionadded:: 2.3.0 | ||
| """ | ||
|
|
||
| if self._ocvTypes is None: | ||
| ctx = SparkContext._active_spark_context | ||
| self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) | ||
| return self._ocvTypes | ||
|
|
||
| @property | ||
| def imageFields(self): | ||
| """ | ||
| Returns field names of image columns. | ||
|
|
||
| :return: a list of field names. | ||
|
|
||
| .. versionadded:: 2.3.0 | ||
| """ | ||
|
|
||
| if self._imageFields is None: | ||
| ctx = SparkContext._active_spark_context | ||
| self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields()) | ||
| return self._imageFields | ||
|
|
||
| @property | ||
| def undefinedImageType(self): | ||
| """ | ||
| Returns the name of undefined image type for the invalid image. | ||
|
|
||
| .. versionadded:: 2.3.0 | ||
| """ | ||
|
|
||
| if self._undefinedImageType is None: | ||
| ctx = SparkContext._active_spark_context | ||
| self._undefinedImageType = \ | ||
| ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType() | ||
| return self._undefinedImageType | ||
|
|
||
| def toNDArray(self, image): | ||
| """ | ||
| Converts an image to an array with metadata. | ||
|
|
||
| :param `Row` image: A row that contains the image to be converted. It should | ||
| have the attributes specified in `ImageSchema.imageSchema`. | ||
| :return: a `numpy.ndarray` that is an image. | ||
|
|
||
| .. versionadded:: 2.3.0 | ||
| """ | ||
|
|
||
| if not isinstance(image, Row): | ||
| raise TypeError( | ||
| "image argument should be pyspark.sql.types.Row; however, " | ||
| "it got [%s]." % type(image)) | ||
|
|
||
| if any(not hasattr(image, f) for f in self.imageFields): | ||
| raise ValueError( | ||
| "image argument should have attributes specified in " | ||
| "ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields)) | ||
|
|
||
| height = image.height | ||
| width = image.width | ||
| nChannels = image.nChannels | ||
| return np.ndarray( | ||
| shape=(height, width, nChannels), | ||
| dtype=np.uint8, | ||
| buffer=image.data, | ||
| strides=(width * nChannels, nChannels, 1)) | ||
|
|
||
| def toImage(self, array, origin=""): | ||
| """ | ||
| Converts an array with metadata to a two-dimensional image. | ||
|
|
||
| :param `numpy.ndarray` array: The array to convert to image. | ||
| :param str origin: Path to the image, optional. | ||
| :return: a :class:`Row` that is a two dimensional image. | ||
|
|
||
| .. versionadded:: 2.3.0 | ||
| """ | ||
|
|
||
| if not isinstance(array, np.ndarray): | ||
| raise TypeError( | ||
| "array argument should be numpy.ndarray; however, it got [%s]." % type(array)) | ||
|
|
||
| if array.ndim != 3: | ||
| raise ValueError("Invalid array shape") | ||
|
|
||
| height, width, nChannels = array.shape | ||
| ocvTypes = ImageSchema.ocvTypes | ||
| if nChannels == 1: | ||
| mode = ocvTypes["CV_8UC1"] | ||
| elif nChannels == 3: | ||
| mode = ocvTypes["CV_8UC3"] | ||
| elif nChannels == 4: | ||
| mode = ocvTypes["CV_8UC4"] | ||
| else: | ||
| raise ValueError("Invalid number of channels") | ||
|
|
||
| # Running `bytearray(numpy.array([1]))` fails in specific Python versions | ||
| # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3. | ||
| # Here, it avoids it by converting it to bytes. | ||
| data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes()) | ||
|
|
||
| # Creating new Row with _create_row(), because Row(name = value, ... ) | ||
| # orders fields by name, which conflicts with expected schema order | ||
| # when the new DataFrame is created by UDF | ||
| return _create_row(self.imageFields, | ||
| [origin, height, width, nChannels, mode, data]) | ||
|
|
||
| def readImages(self, path, recursive=False, numPartitions=-1, | ||
| dropImageFailures=False, sampleRatio=1.0, seed=0): | ||
| """ | ||
| Reads the directory of images from the local or remote source. | ||
|
|
||
| .. note:: If multiple jobs are run in parallel with different sampleRatio or recursive flag, | ||
| there may be a race condition where one job overwrites the hadoop configs of another. | ||
|
|
||
| .. note:: If sample ratio is less than 1, sampling uses a PathFilter that is efficient but | ||
| potentially non-deterministic. | ||
|
|
||
| :param str path: Path to the image directory. | ||
| :param bool recursive: Recursive search flag. | ||
| :param int numPartitions: Number of DataFrame partitions. | ||
| :param bool dropImageFailures: Drop the files that are not valid images. | ||
| :param float sampleRatio: Fraction of the images loaded. | ||
| :param int seed: Random number seed. | ||
| :return: a :class:`DataFrame` with a single column of "images", | ||
| see ImageSchema for details. | ||
|
|
||
| >>> df = ImageSchema.readImages('python/test_support/image/kittens', recursive=True) | ||
| >>> df.count() | ||
| 4 | ||
|
|
||
| .. versionadded:: 2.3.0 | ||
| """ | ||
|
|
||
| ctx = SparkContext._active_spark_context | ||
| spark = SparkSession(ctx) | ||
| image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema | ||
| jsession = spark._jsparkSession | ||
| jresult = image_schema.readImages(path, jsession, recursive, numPartitions, | ||
| dropImageFailures, float(sampleRatio), seed) | ||
| return DataFrame(jresult, spark._wrapped) | ||
|
|
||
|
|
||
| ImageSchema = _ImageSchema() | ||
|
|
||
|
|
||
| # Monkey patch to disallow instantization of this class. | ||
| def _disallow_instance(_): | ||
| raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") | ||
| _ImageSchema.__init__ = _disallow_instance | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,8 +13,18 @@ | |
| # limitations under the License. | ||
| # | ||
|
|
||
|
|
||
| # hack to import copy-pasted image schema (to be removed in Spark2.3) | ||
| # TODO remove in Spark2.3 | ||
| import os | ||
| import pyspark.ml | ||
| dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
| parentdir = os.path.dirname(dir_path) | ||
| pyspark.ml.__path__.append(os.path.join(parentdir, "pyspark", "ml")) | ||
|
|
||
| from pyspark.ml.image import ImageSchema | ||
|
Collaborator
Move above this group to its own group. |
||
|
|
||
|
Contributor
I'm concerned that monkey patching the spark.ml might introduce unexpected behaviour for the user. Specifically, I'm worried that this will work: […] but this will not: […] The alternative is to put this in […] I don't have a preference here; I just want to make sure both options are considered.
Contributor
Author
My take is that we have two options here: […] I have a slight preference for 1, since it has the nice property that your script will keep working once you get it to work the first time.
Contributor
If we simply add […]
Contributor
Author
Yeah, OK, that works. I can make the change.
Contributor
Author
Actually, I still find it confusing that you end up in a state where you can mix imports of sparkdl.image.ImageSchema and pyspark.ml.image.ImageSchema and they are actually the same thing. It might be better than the monkey patch; however, the monkey patch is only temporary, while this solution would persist. |
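The snippets from this thread did not survive, but the import-order dependence being discussed can be reproduced with a throwaway package. This is a minimal sketch, assuming nothing about sparkdl itself: `demo_hostpkg` and the directory names are hypothetical, standing in for `pyspark.ml` and the vendored `pyspark/ml` directory.

```python
import importlib
import os
import sys
import tempfile

# Build two throwaway trees: a "real" package and a separate directory
# that holds an extra submodule for it. All names are hypothetical.
root = tempfile.mkdtemp()
pkg_dir = os.path.join(root, "site", "demo_hostpkg")
extra_dir = os.path.join(root, "vendored", "demo_hostpkg")
os.makedirs(pkg_dir)
os.makedirs(extra_dir)

with open(os.path.join(pkg_dir, "__init__.py"), "w") as f:
    f.write("")
with open(os.path.join(extra_dir, "image.py"), "w") as f:
    f.write("ImageSchema = 'vendored copy'\n")

sys.path.insert(0, os.path.join(root, "site"))
import demo_hostpkg


def try_import(name):
    """Return the imported module, or None if it cannot be found."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# Before the patch, the submodule does not exist for this package.
before = try_import("demo_hostpkg.image")

# The patch: extend the package's search path, as the diff does for pyspark.ml.
demo_hostpkg.__path__.append(extra_dir)
importlib.invalidate_caches()
after = try_import("demo_hostpkg.image")

print(before is None)
print(after.ImageSchema)
```

Whether the submodule import succeeds depends on whether the code that patches `__path__` has already run, which is the user-visible surprise raised above.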
||
| from .graph.input import TFInputGraph | ||
| from .image.imageIO import imageSchema, imageType, readImages | ||
| from .transformers.keras_image import KerasImageFileTransformer | ||
| from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer | ||
| from .transformers.tf_image import TFImageTransformer | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,7 +18,6 @@ | |
| import tensorflow as tf | ||
|
|
||
| from sparkdl.graph.builder import IsolatedSession | ||
| from sparkdl.image.imageIO import SparkMode | ||
|
|
||
| logger = logging.getLogger('sparkdl') | ||
|
|
||
|
|
@@ -48,14 +47,13 @@ def buildSpImageConverter(img_dtype): | |
| # This is the default behavior of Python Image Library | ||
| shape = tf.reshape(tf.stack([height, width, num_channels], axis=0), | ||
| shape=(3,), name='shape') | ||
| if img_dtype == SparkMode.RGB: | ||
| if img_dtype == 'uint8': | ||
| image_uint8 = tf.decode_raw(image_buffer, tf.uint8, name="decode_raw") | ||
| image_float = tf.to_float(image_uint8) | ||
| else: | ||
| assert img_dtype == SparkMode.RGB_FLOAT32, \ | ||
| "Unsupported dtype for image: {}".format(img_dtype) | ||
| elif img_dtype == 'float32': | ||
| image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw") | ||
|
|
||
| else: | ||
|
Collaborator
In the new schema, are there legitimate types that have float64 (or any other dtype) as img_dtype?
Contributor
Author
AFAIK the schema does not specify types. It only specifies a field holding an OpenCV type number. There are OpenCV types that use float64. Technically, the schema includes an openCvTypes map with only a subset of types; however, we already need types outside this subset (TF-produced images are stored as float32).
Collaborator
So does ImageSchema support OpenCV types that have float64? If so, should we support them here?
Contributor
Author
Currently, as far as I know, there is no way to get a float64 image. ImageSchema as a data format supports it in that it has a mode field which is supposed to hold an OpenCV type, and there are OpenCV types with float64. However, float64 is not in the list of OpenCV types in their Scala code (and neither are any float32 types, which we need), and as it stands now, readImages can only ever produce images stored as unsigned bytes (both the Scala and PIL versions), i.e. one of the CV_8U* formats. We also need the float32 formats, since that's what we return when returning images from TF, so I added those on our Python side. The Python code from ImageSchema can only handle unsigned-byte images; that's why I use our own versions in imageIO (imageArrayToStruct and imageStructToArray).
Collaborator
From offline discussion: The ImageSchema utilities in Spark only support uint8 types. Ideally, float32 types would also be supported natively in Spark so that we don't need special logic in this package to handle them. We'll create a Jira in Spark for that and try to address it there.
Collaborator
Do we have a Jira for this already? If so, could you link it from here?
Contributor
Author
Yes, we do: https://issues.apache.org/jira/browse/SPARK-22730 Do you mean you want it in the code? That would probably go in imageIO; I'll put it there. |
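For context on the "OpenCV type number" mentioned in this thread: OpenCV packs the element depth and the channel count into one integer. The sketch below decodes that integer using the depth constants from OpenCV's core headers. `cv_make_type` and `decode_cv_type` are hypothetical helper names, not part of ImageSchema or sparkdl.

```python
# OpenCV depth codes (CV_8U .. CV_64F) mapped to NumPy-style dtype names.
CV_DEPTHS = {
    0: "uint8",    # CV_8U
    1: "int8",     # CV_8S
    2: "uint16",   # CV_16U
    3: "int16",    # CV_16S
    4: "int32",    # CV_32S
    5: "float32",  # CV_32F
    6: "float64",  # CV_64F
}
CV_CN_SHIFT = 3  # the low 3 bits hold the depth, the rest hold (channels - 1)


def cv_make_type(depth, channels):
    """Combine a depth code and a channel count into an OpenCV type number."""
    return depth + ((channels - 1) << CV_CN_SHIFT)


def decode_cv_type(mode):
    """Split an OpenCV type number into (dtype name, channel count)."""
    if mode < 0:
        raise ValueError("undefined image type: %d" % mode)
    depth = mode & ((1 << CV_CN_SHIFT) - 1)
    channels = (mode >> CV_CN_SHIFT) + 1
    if depth not in CV_DEPTHS:
        raise ValueError("unsupported depth code: %d" % depth)
    return CV_DEPTHS[depth], channels


print(decode_cv_type(16))   # CV_8UC3
print(cv_make_type(5, 3))   # CV_32FC3
```

Under this encoding a three-channel float32 image has a different type number from CV_8UC3, so a mode field can in principle distinguish the uint8 and float32 cases discussed here even though Spark's utilities only handle the former.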
||
| raise ValueError('unsupported image data type "%s", currently only know how to handle uint8 and float32' % img_dtype) | ||
| image_reshaped = tf.reshape(image_float, shape, name="reshaped") | ||
| image_input = tf.expand_dims(image_reshaped, 0, name="image_input") | ||
| gfn = issn.asGraphFunction([height, width, image_buffer, num_channels], [image_input]) | ||
|
|
||
allready -> already
nit: probably needs to go under the license stuff