diff --git a/notebooks/samples/302 - Pipeline Image Transformations.ipynb b/notebooks/samples/302 - Pipeline Image Transformations.ipynb index 06996a41719..db143ba457f 100644 --- a/notebooks/samples/302 - Pipeline Image Transformations.ipynb +++ b/notebooks/samples/302 - Pipeline Image Transformations.ipynb @@ -64,10 +64,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "mml-deploy": "hdinsight", - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "IMAGE_PATH = \"/datasets/CIFAR10/test\"" ] @@ -97,6 +94,64 @@ "print(images.count())" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternatively, we can stream the images with a similar API.\n", + "Check the [Structured Streaming Programming Guide](", + "https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html)\n", + "for more details on streaming." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "imageStream = spark.streamImages(IMAGE_PATH + \"/*\", sampleRatio = 0.1)\n", + "query = imageStream.select(\"image.height\").writeStream.format(\"memory\").queryName(\"heights\").start()\n", + "print(\"Streaming query activity: {}\".format(query.isActive))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Wait a few seconds and then try querying for the images below.\n", + "Note that when streaming a directory of images that already exists, it will\n", + "consume all images in a single batch. If one were to move images into the\n", + "directory, the streaming engine would pick up on them and send them as\n", + "another batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heights = spark.sql(\"select * from heights\")\n", + "print(\"Streamed {} heights\".format(heights.count()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After we have streamed the images, we can stop the query:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query.stop()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -208,6 +263,13 @@ "print(type(vector))\n", "len(vector.toArray())" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } ], "metadata": { @@ -227,7 +289,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.2" + "version": "3.6.1" } }, "nbformat": 4, diff --git a/src/cntk-model/src/main/scala/CNTKModel.scala b/src/cntk-model/src/main/scala/CNTKModel.scala index 8580c4406f0..f42b139097e 100644 --- a/src/cntk-model/src/main/scala/CNTKModel.scala +++ b/src/cntk-model/src/main/scala/CNTKModel.scala @@ -13,6 +13,7 @@ import org.apache.spark.ml.linalg.{DenseVector, Vectors} import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -245,14 +246,14 @@ class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexP val inputType = df.schema($(inputCol)).dataType val broadcastedModel = broadcastedModelOption.getOrElse(spark.sparkContext.broadcast(getModel)) - val rdd = df.rdd.mapPartitions( + val encoder = RowEncoder(df.schema.add(StructField(getOutputCol, VectorType))) + val output = df.mapPartitions(
CNTKModelUtils.applyModel(selectedIndex, broadcastedModel, getMiniBatchSize, getInputNode, get(outputNodeName), - get(outputNodeIndex))) - val output = spark.createDataFrame(rdd, df.schema.add(StructField(getOutputCol, VectorType))) + get(outputNodeIndex)))(encoder) coersionOptionUDF match { case Some(_) => output.drop(coercedCol) diff --git a/src/core/env/src/main/scala/StreamUtilities.scala b/src/core/env/src/main/scala/StreamUtilities.scala index 20f14b90aec..1c82c2d5671 100644 --- a/src/core/env/src/main/scala/StreamUtilities.scala +++ b/src/core/env/src/main/scala/StreamUtilities.scala @@ -3,10 +3,12 @@ package com.microsoft.ml.spark -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, InputStream} import java.util.zip.ZipInputStream + import org.apache.commons.io.IOUtils import org.apache.spark.input.PortableDataStream + import scala.util.Random object StreamUtilities { @@ -34,24 +36,17 @@ object StreamUtilities { /** Iterate through the entries of a streamed .zip file, selecting only sampleRatio of them * - * @param portableStream Stream of zip file - * @param zipfile File name is only used to construct the names of the entries - * @param sampleRatio What fraction of files is returned from zip + * @param stream Stream of zip file + * @param zipfile File name is only used to construct the names of the entries + * @param sampleRatio What fraction of files is returned from zip */ - class ZipIterator(portableStream: PortableDataStream, zipfile: String, sampleRatio: Double = 1) + class ZipIterator(stream: InputStream, zipfile: String, random: Random, sampleRatio: Double = 1) extends Iterator[(String, Array[Byte])] { - val stream = portableStream.open - private val zipstream = new ZipInputStream(stream) - - val random = { - val rd = new Random() - rd.setSeed(0) - rd - } + private val zipStream = new ZipInputStream(stream) private def getNext: Option[(String, Array[Byte])] = { - var entry = zipstream.getNextEntry + var entry = zipStream.getNextEntry while (entry != null) { if (!entry.isDirectory && random.nextDouble < sampleRatio) { @@ -59,7 +54,7 @@ object StreamUtilities { //extracting all bytes of a given entry val byteStream = new ByteArrayOutputStream - IOUtils.copy(zipstream, byteStream) + IOUtils.copy(zipStream, byteStream) val bytes = byteStream.toByteArray assert(bytes.length == entry.getSize, @@ -67,7 +62,7 @@ object StreamUtilities { return Some((filename, bytes)) } - entry = zipstream.getNextEntry + entry = zipStream.getNextEntry } stream.close() @@ -76,7 +71,7 @@ object StreamUtilities { private var nextValue = getNext - def hasNext: Boolean = !nextValue.isEmpty + def hasNext: Boolean = nextValue.isDefined def next: (String, Array[Byte]) = { val result = nextValue.get diff --git a/src/core/hadoop/src/main/scala/HadoopUtils.scala b/src/core/hadoop/src/main/scala/HadoopUtils.scala index cc707091371..c18a5d4a9a0 100644 --- a/src/core/hadoop/src/main/scala/HadoopUtils.scala +++ b/src/core/hadoop/src/main/scala/HadoopUtils.scala @@ -3,17 +3,9 @@ package com.microsoft.ml.spark.hadoop -import java.nio.file.Paths - -import org.apache.commons.io.FilenameUtils - -import scala.sys.process._ -import org.apache.hadoop.conf.{Configuration, Configured} -import org.apache.hadoop.fs.{Path, PathFilter} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.spark.sql.SparkSession +import org.apache.hadoop.conf.Configuration import scala.language.existentials -import scala.util.Random +import scala.sys.process._ class HadoopUtils(hadoopConf: 
Configuration) { // Is there a better way? We need to deduce full Hadoop conf @@ -73,102 +65,3 @@ class HadoopUtils(hadoopConf: Configuration) { } } - -/** Filter that allows loading a fraction of HDFS files. */ -class SamplePathFilter extends Configured with PathFilter { - val random = { - val rd = new Random() - rd.setSeed(0) - rd - } - - // Ratio of files to be read from disk - var sampleRatio: Double = 1 - - // When inspectZip is enabled, zip files are treated as directories, and SamplePathFilter can't filter them out. - // Otherwise, zip files are treated as regular files and only sampleRatio of them is read. - var inspectZip: Boolean = true - - override def setConf(conf: Configuration): Unit = { - if (conf != null) { - sampleRatio = conf.getDouble(SamplePathFilter.ratioParam, 1) - inspectZip = conf.getBoolean(SamplePathFilter.inspectZipParam, true) - } - } - - override def accept(path: Path): Boolean = { - // Note: checking fileSystem.isDirectory is very slow here, so we use basic rules instead - !SamplePathFilter.isFile(path) || - (SamplePathFilter.isZipFile(path) && inspectZip) || - random.nextDouble() < sampleRatio - } -} - -object SamplePathFilter { - val ratioParam = "sampleRatio" - val inspectZipParam = "inspectZip" - - def isFile(path: Path): Boolean = FilenameUtils.getExtension(path.toString) != "" - - def isZipFile(filename: String): Boolean = FilenameUtils.getExtension(filename) == "zip" - - def isZipFile(path: Path): Boolean = isZipFile(path.toString) - - /** Set/unset hdfs PathFilter - * - * @param value Filter class that is passed to HDFS - * @param sampleRatio Fraction of the files that the filter picks - * @param inspectZip Look into zip files, if true - * @param spark Existing Spark session - * @return - */ - def setPathFilter(value: Option[Class[_]], sampleRatio: Option[Double] = None, - inspectZip: Option[Boolean] = None, spark: SparkSession) - : Option[Class[_]] = { - val flagName = FileInputFormat.PATHFILTER_CLASS - val hadoopConf = spark.sparkContext.hadoopConfiguration - val old = Option(hadoopConf.getClass(flagName, null)) - if (sampleRatio.isDefined) { - hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio.get) - } else { - hadoopConf.unset(SamplePathFilter.ratioParam) - None - } - - if (inspectZip.isDefined) { - hadoopConf.setBoolean(SamplePathFilter.inspectZipParam, inspectZip.get) - } else { - hadoopConf.unset(SamplePathFilter.inspectZipParam) - None - } - - value match { - case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter]) - case None => hadoopConf.unset(flagName) - } - old - } -} - -object RecursiveFlag { - - /** Sets a value of spark recursive flag - * - * @param value value to set - * @param spark existing spark session - * @return previous value of this flag - */ - def setRecursiveFlag(value: Option[String], spark: SparkSession): Option[String] = { - val flagName = FileInputFormat.INPUT_DIR_RECURSIVE - val hadoopConf = spark.sparkContext.hadoopConfiguration - val old = Option(hadoopConf.get(flagName)) - - value match { - case Some(v) => hadoopConf.set(flagName, v) - case None => hadoopConf.unset(flagName) - } - - old - } - -} diff --git a/src/core/schema/src/main/scala/BinaryFileSchema.scala b/src/core/schema/src/main/scala/BinaryFileSchema.scala index 39cf8abe660..277011d0c16 100644 --- a/src/core/schema/src/main/scala/BinaryFileSchema.scala +++ b/src/core/schema/src/main/scala/BinaryFileSchema.scala @@ -4,16 +4,19 @@ package com.microsoft.ml.spark.schema import org.apache.spark.sql.{DataFrame, Row} -import 
org.apache.spark.sql.types.{StructType, StructField, StringType, BinaryType} +import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType} object BinaryFileSchema { /** Schema for the binary file column: Row(String, Array[Byte]) */ val columnSchema = StructType(Seq( - StructField("path", StringType, true), - StructField("bytes", BinaryType, true) //raw file bytes + StructField("path", StringType, true), + StructField("bytes", BinaryType, true) // raw file bytes )) + /** Schema for the binary file column: Row(String, Array[Byte]) */ + val schema = StructType(StructField("value", columnSchema, true) :: Nil) + def getPath(row: Row): String = row.getString(0) def getBytes(row: Row): Array[Byte] = row.getAs[Array[Byte]](1) diff --git a/src/core/schema/src/main/scala/ImageSchema.scala b/src/core/schema/src/main/scala/ImageSchema.scala index 1a993b84d03..4447f5de541 100644 --- a/src/core/schema/src/main/scala/ImageSchema.scala +++ b/src/core/schema/src/main/scala/ImageSchema.scala @@ -3,21 +3,21 @@ package com.microsoft.ml.spark.schema -import com.microsoft.ml.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types._ -import scala.reflect.ClassTag +import org.apache.spark.sql.{DataFrame, Row} object ImageSchema { /** Schema for the image column: Row(String, Int, Int, Int, Array[Byte]) */ val columnSchema = StructType( - StructField("path", StringType, true) :: - StructField("height", IntegerType, true) :: - StructField("width", IntegerType, true) :: - StructField("type", IntegerType, true) :: //OpenCV type: CV_8U in most cases - StructField("bytes", BinaryType, true) :: Nil) //OpenCV bytes: row-wise BGR in most cases + StructField("path", StringType, true) :: + StructField("height", IntegerType, true) :: + StructField("width", IntegerType, true) :: + StructField("type", IntegerType, true) :: //OpenCV type: CV_8U in most cases + StructField("bytes", BinaryType, true) :: Nil) //OpenCV bytes: row-wise BGR in most cases + + // single column of images named "image" + val schema = StructType(StructField("image", columnSchema, true) :: Nil) def getPath(row: Row): String = row.getString(0) def getHeight(row: Row): Int = row.getInt(1) @@ -34,17 +34,4 @@ object ImageSchema { def isImage(df: DataFrame, column: String): Boolean = df.schema(column).dataType == columnSchema - /** This object will load the openCV binaries when the object is referenced - * for the first time, subsequent references will not re-load the binaries. - * In spark, this loads one copy for each running jvm, instead of once per partition. 
- * This technique is similar to that used by the cntk_jni jar, - * but in the case where microsoft cannot edit the jar - */ - private[spark] object OpenCVLoader { - import org.opencv.core.Core - new NativeLoader("/nu/pattern/opencv").loadLibraryByName(Core.NATIVE_LIBRARY_NAME) - } - - private[spark] def loadOpenCV[T:ClassTag](rdd: RDD[T]):RDD[T] = - rdd.mapPartitions({it => OpenCVLoader; it}, preservesPartitioning = true) } diff --git a/src/core/test/base/src/main/scala/TestBase.scala b/src/core/test/base/src/main/scala/TestBase.scala index 26341033a97..ecb30d4a1a5 100644 --- a/src/core/test/base/src/main/scala/TestBase.scala +++ b/src/core/test/base/src/main/scala/TestBase.scala @@ -12,6 +12,7 @@ import org.apache.spark.ml.util.{MLReadable, MLWritable} import org.apache.spark.sql.{DataFrame, _} import org.apache.commons.io.FileUtils import org.apache.spark.ml.linalg.DenseVector +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.scalactic.{Equality, TolerantNumerics} import org.scalactic.source.Position import org.scalatest._ @@ -50,6 +51,8 @@ abstract class TestBase extends FunSuite with BeforeAndAfterEachTestData with Be } protected lazy val sc: SparkContext = session.sparkContext + protected lazy val ssc: StreamingContext = new StreamingContext(sc, Seconds(1)) + protected lazy val dir = SparkSessionFactory.workingDir private var tmpDirCreated = false diff --git a/src/featurize/src/main/scala/AssembleFeatures.scala b/src/featurize/src/main/scala/AssembleFeatures.scala index d0986c54d93..2d52e83328b 100644 --- a/src/featurize/src/main/scala/AssembleFeatures.scala +++ b/src/featurize/src/main/scala/AssembleFeatures.scala @@ -4,27 +4,26 @@ package com.microsoft.ml.spark import java.io._ -import java.sql.{Date, Time, Timestamp} +import java.sql.{Date, Timestamp} import java.time.temporal.ChronoField -import com.microsoft.ml.spark.schema.{CategoricalColumnInfo, DatasetExtensions, ImageSchema} import com.microsoft.ml.spark.schema.DatasetExtensions._ -import org.apache.hadoop.fs.Path +import com.microsoft.ml.spark.schema.{CategoricalColumnInfo, DatasetExtensions, ImageSchema} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.feature._ +import org.apache.spark.ml.linalg.SQLDataTypes.VectorType +import org.apache.spark.ml.linalg.{SparseVector, Vectors} import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.SQLDataTypes.VectorType -import org.apache.spark.ml.linalg.{SparseVector, Vectors} import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StringType, _} +import scala.collection.immutable.{BitSet, HashSet} import scala.collection.mutable import scala.collection.mutable.ListBuffer -import scala.collection.immutable.{BitSet, HashSet} import scala.reflect.runtime.universe.{TypeTag, typeTag} private object AssembleFeaturesUtilities diff --git a/src/image-featurizer/src/test/scala/ImageFeaturizerSuite.scala b/src/image-featurizer/src/test/scala/ImageFeaturizerSuite.scala index 2bac6588baa..c2fd02a05a1 100644 --- a/src/image-featurizer/src/test/scala/ImageFeaturizerSuite.scala +++ b/src/image-featurizer/src/test/scala/ImageFeaturizerSuite.scala @@ -5,16 +5,17 @@ package com.microsoft.ml.spark import java.net.URI -import org.apache.spark.sql.DataFrame import com.microsoft.ml.spark.FileUtilities.File -import org.apache.spark.ml.linalg.DenseVector import 
com.microsoft.ml.spark.Readers.implicits._ +import com.microsoft.ml.spark.schema.ImageSchema +import org.apache.commons.io.FileUtils import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.linalg.DenseVector import org.apache.spark.ml.util.{MLReadable, MLWritable} +import org.apache.spark.sql.DataFrame +import org.apache.spark.image.ImageFileFormat -import scala.collection.JavaConversions._ - -class ImageFeaturizerSuite extends LinuxOnly with CNTKTestUtils with RoundTripTestBase { +class ImageFeaturizerSuite extends LinuxOnly with CNTKTestUtils with RoundTripTestBase with FileReaderUtils { val images: DataFrame = session.readImages(imagePath, true).withColumnRenamed("image", inputCol) val modelDir = new File(filesRoot, "CNTKModel") @@ -34,6 +35,33 @@ class ImageFeaturizerSuite extends LinuxOnly with CNTKTestUtils with RoundTripTe compareToTestModel(result) } + test("structured streaming"){ + + val model = new ImageFeaturizer() + .setInputCol("image") + .setOutputCol(outputCol) + .setModelLocation(session, s"${sys.env("DATASETS_HOME")}/CNTKModel/ConvNet_CIFAR10.model") + .setCutOutputLayers(0) + .setLayerNames(Array("z")) + + val imageDF = session + .readStream + .format(classOf[ImageFileFormat].getName) + .schema(ImageSchema.schema) + .load(cifarDirectory) + + val resultDF = model.transform(imageDF) + + val q1 = resultDF.writeStream + .format("memory") + .queryName("images") + .start() + + tryWithRetries(){ () => + assert(session.sql("select * from images").count() == 6) + } + } + def resNetModel(): ImageFeaturizer = new ImageFeaturizer() .setInputCol(inputCol) .setOutputCol(outputCol) @@ -44,6 +72,16 @@ class ImageFeaturizerSuite extends LinuxOnly with CNTKTestUtils with RoundTripTe compareToTestModel(result) } + test("the Image feature should work with the modelSchema + new images") { + val newImages = session.read + .format(classOf[ImageFileFormat].getName) + .load(cifarDirectory) + .withColumnRenamed("image","cntk_images") + + val result = resNetModel().setCutOutputLayers(0).transform(newImages) + compareToTestModel(result) + } + test("Image featurizer should work with ResNet50", TestBase.Extended) { val result = resNetModel().transform(images) val resVec = result.select(outputCol).collect()(0).getAs[DenseVector](0) diff --git a/src/image-transformer/src/main/scala/ImageTransformer.scala b/src/image-transformer/src/main/scala/ImageTransformer.scala index 9e8fab9ce05..158ea53c0b4 100644 --- a/src/image-transformer/src/main/scala/ImageTransformer.scala +++ b/src/image-transformer/src/main/scala/ImageTransformer.scala @@ -3,24 +3,17 @@ package com.microsoft.ml.spark +import com.microsoft.ml.spark.schema.{BinaryFileSchema, ImageSchema} import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.DefaultParamsReadable +import org.apache.spark.ml.param.{ParamMap, _} +import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.ml.param._ -import com.microsoft.ml.spark.schema.ImageSchema +import org.opencv.core.{Core, Mat, Rect, Size} +import org.opencv.imgproc.Imgproc import scala.collection.mutable.ListBuffer -import com.microsoft.ml.spark.schema.BinaryFileSchema - -import scala.collection.mutable.{ListBuffer, WrappedArray} -import org.opencv.core.Core -import org.opencv.core.Mat -import org.opencv.core.{Rect, Size} -import org.opencv.imgproc.Imgproc -import 
org.apache.spark.ml.util.Identifiable /** Image processing stage. * @param params Map of parameters @@ -242,7 +235,7 @@ object ImageTransformer extends DefaultParamsReadable[ImageTransformer] { if (row == null) return None val decoded = if (decode) { - val path = BinaryFileSchema.getPath(row) + val path = BinaryFileSchema.getPath(row) val bytes = BinaryFileSchema.getBytes(row) //early return if the image can't be decompressed @@ -341,9 +334,7 @@ class ImageTransformer(val uid: String) extends Transformer val schema = dataset.toDF.schema - val loaded = ImageSchema.loadOpenCV(dataset.toDF.rdd) - - val df = spark.createDataFrame(loaded, schema) + val df = ImageReader.loadOpenCV(dataset.toDF) val isBinary = BinaryFileSchema.isBinaryFile(df, $(inputCol)) assert(ImageSchema.isImage(df, $(inputCol)) || isBinary, "input column should have Image or BinaryFile type") diff --git a/src/image-transformer/src/main/scala/UnrollImage.scala b/src/image-transformer/src/main/scala/UnrollImage.scala index 41ed8e2dea2..dc8afa7843d 100644 --- a/src/image-transformer/src/main/scala/UnrollImage.scala +++ b/src/image-transformer/src/main/scala/UnrollImage.scala @@ -7,7 +7,7 @@ import com.microsoft.ml.spark.schema.ImageSchema._ import org.apache.spark.ml.Transformer import org.apache.spark.ml.linalg.DenseVector import org.apache.spark.ml.linalg.SQLDataTypes.VectorType -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types._ diff --git a/src/readers/src/main/python/BinaryFileReader.py b/src/readers/src/main/python/BinaryFileReader.py index 60254451e7f..364b69aa057 100644 --- a/src/readers/src/main/python/BinaryFileReader.py +++ b/src/readers/src/main/python/BinaryFileReader.py @@ -29,7 +29,7 @@ bytes """ -def readBinaryFiles(self, path, recursive = False, sampleRatio = 1.0, inspectZip = True): +def readBinaryFiles(self, path, recursive = False, sampleRatio = 1.0, inspectZip = True, seed=0): """ Reads the directory of binary files from the local or remote (WASB) source This function is attached to SparkSession class. @@ -50,11 +50,36 @@ def readBinaryFiles(self, path, recursive = False, sampleRatio = 1.0, inspectZip reader = ctx._jvm.com.microsoft.ml.spark.BinaryFileReader sql_ctx = pyspark.SQLContext.getOrCreate(ctx) jsession = sql_ctx.sparkSession._jsparkSession - jresult = reader.read(path, recursive, jsession, float(sampleRatio), inspectZip) + jresult = reader.read(path, recursive, jsession, float(sampleRatio), inspectZip, seed) return DataFrame(jresult, sql_ctx) setattr(sql.SparkSession, 'readBinaryFiles', classmethod(readBinaryFiles)) +def streamBinaryFiles(self, path, sampleRatio = 1.0, inspectZip = True, seed=0): + """ + Streams the directory of binary files from the local or remote (WASB) source + This function is attached to SparkSession class. 
+ + :Example: + + >>> spark.streamBinaryFiles(path, sampleRatio = 1.0, inspectZip = True) + + Args: + path (str): Path to the file directory + + Returns: + DataFrame: DataFrame with a single column "value"; see binaryFileSchema for details + + """ + ctx = SparkContext.getOrCreate() + reader = ctx._jvm.com.microsoft.ml.spark.BinaryFileReader + sql_ctx = pyspark.SQLContext.getOrCreate(ctx) + jsession = sql_ctx.sparkSession._jsparkSession + jresult = reader.stream(path, jsession, float(sampleRatio), inspectZip, seed) + return DataFrame(jresult, sql_ctx) + +setattr(sql.SparkSession, 'streamBinaryFiles', classmethod(streamBinaryFiles)) + def isBinaryFile(df, column): """ Returns True if the column contains binary files diff --git a/src/readers/src/main/python/ImageReader.py b/src/readers/src/main/python/ImageReader.py index e8ae31d95e1..fd2d476a3ff 100644 --- a/src/readers/src/main/python/ImageReader.py +++ b/src/readers/src/main/python/ImageReader.py @@ -13,7 +13,7 @@ from pyspark.sql import DataFrame -def readImages(sparkSession, path, recursive = False, sampleRatio = 1.0, inspectZip = True): +def readImages(sparkSession, path, recursive = False, sampleRatio = 1.0, inspectZip = True, seed = 0): """ Reads the directory of images from the local or remote (WASB) source. This function is attached to SparkSession class. @@ -33,11 +33,36 @@ def readImages(sparkSession, path, recursive = False, sampleRatio = 1.0, inspect reader = ctx._jvm.com.microsoft.ml.spark.ImageReader sql_ctx = pyspark.SQLContext.getOrCreate(ctx) jsession = sql_ctx.sparkSession._jsparkSession - jresult = reader.read(path, recursive, jsession, float(sampleRatio), inspectZip) + jresult = reader.read(path, recursive, jsession, float(sampleRatio), inspectZip, seed) return DataFrame(jresult, sql_ctx) setattr(sql.SparkSession, 'readImages', classmethod(readImages)) +def streamImages(sparkSession, path, sampleRatio = 1.0, inspectZip = True, seed = 0): + """ + Reads the directory of images from the local or remote (WASB) source. + This function is attached to SparkSession class. + Example: spark.streamImages(path, .5, ...) + + Args: + sparkSession (SparkSession): Existing sparkSession + path (str): Path to the image directory + sampleRatio (double): Fraction of the images loaded + inspectZip: (boolean): Whether to look inside zip folders + + Returns: + DataFrame: DataFrame with a single column of "images", see imageSchema + for details + """ + ctx = SparkContext.getOrCreate() + reader = ctx._jvm.com.microsoft.ml.spark.ImageReader + sql_ctx = pyspark.SQLContext.getOrCreate(ctx) + jsession = sql_ctx.sparkSession._jsparkSession + jresult = reader.stream(path, jsession, float(sampleRatio), inspectZip, seed) + return DataFrame(jresult, sql_ctx) + +setattr(sql.SparkSession, 'streamImages', classmethod(streamImages)) + def isImage(df, column): """ Returns True if the column contains images diff --git a/src/readers/src/main/scala/BinaryFileFormat.scala b/src/readers/src/main/scala/BinaryFileFormat.scala new file mode 100644 index 00000000000..56515cee7f5 --- /dev/null +++ b/src/readers/src/main/scala/BinaryFileFormat.scala @@ -0,0 +1,208 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. 
+ +package org.apache.spark.binary + +import java.io.{Closeable, InputStream} +import java.net.URI + +import com.microsoft.ml.spark.StreamUtilities.ZipIterator +import com.microsoft.ml.spark.schema.BinaryFileSchema +import org.apache.commons.io.{FilenameUtils, IOUtils} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.log4j.Logger +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration + +import scala.util.Random + +/** Actually reads the records from files + * + * @param subsample what ratio to subsample + * @param inspectZip whether to inspect zip files + */ +private[spark] class BinaryRecordReader(val subsample: Double, val inspectZip: Boolean, val seed: Long) + extends RecordReader[String, BytesWritable] { + + private var done: Boolean = false + private var inputStream: InputStream = _ + private var filename: String = _ + private var recordValue: BytesWritable = _ + private var progress: Float = 0.0F + private val rng: Random = new Random() + private var zipIterator: ZipIterator = _ + + override def close(): Unit = { + if (inputStream != null) { + inputStream.close() + } + } + + override def getCurrentKey: String = { + filename + } + + override def getCurrentValue: BytesWritable = { + recordValue + } + + override def getProgress: Float = { + progress + } + + override def initialize(inputSplit: InputSplit, context: TaskAttemptContext): Unit = { + // the file input + val fileSplit = inputSplit.asInstanceOf[FileSplit] + + val file = fileSplit.getPath // the actual file we will be reading from + val conf = context.getConfiguration // job configuration + val fs = file.getFileSystem(conf) // get the filesystem + filename = file.toString // open the File + + inputStream = fs.open(file) + rng.setSeed(filename.hashCode.toLong ^ seed) + if (inspectZip && FilenameUtils.getExtension(filename) == "zip") { + zipIterator = new ZipIterator(inputStream, filename, rng, subsample) + } + } + + def markAsDone(): Unit = { + done = true + progress = 1.0F + } + + override def nextKeyValue(): Boolean = { + if (done) { + return false + } + + if (zipIterator != null) { + if (zipIterator.hasNext) { + val (fn, barr) = zipIterator.next + filename = fn + recordValue = new BytesWritable() + recordValue.set(barr, 0, barr.length) + true + } else { + markAsDone() + false + } + } else { + if (rng.nextDouble() <= subsample) { + val barr = IOUtils.toByteArray(inputStream) + recordValue = new BytesWritable() + recordValue.set(barr, 0, barr.length) + markAsDone() + true + } else { + markAsDone() + false + } + } + } +} + +/** File format used for structured streaming of binary files */ +class BinaryFileFormat extends TextBasedFileFormat with DataSourceRegister { + + override def isSplitable(sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = false + + override def shortName(): String = "binary" + + override def inferSchema( + sparkSession: 
SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + if (files.isEmpty) { + None + } else { + Some(BinaryFileSchema.schema) + } + } + + override def prepareWrite(sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new NotImplementedError("writing to binary files is not supported") + } + + override def buildReader(sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val subsample = options.getOrElse("subsample", "1.0").toDouble + val inspectZip = options.getOrElse("inspectZip", "false").toBoolean + val seed = options.getOrElse("seed", "0").toLong + + assert(subsample >= 0.0 & subsample <= 1.0) + (file: PartitionedFile) => { + val fileReader = new HadoopFileReader(file, broadcastedHadoopConf.value.value, subsample, inspectZip, seed) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => fileReader.close())) + fileReader.map { bytes => + val row = new GenericInternalRow(2) + row.update(0, UTF8String.fromString(file.filePath)) + row.update(1, bytes.getBytes) + val outerRow = new GenericInternalRow(1) + outerRow.update(0,row) + outerRow + } + } + } + + override def toString: String = "Binary" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[BinaryFileFormat] +} + +/** Thin wrapper class analogous to others in the spark ecosystem */ +private[spark] class HadoopFileReader(file: PartitionedFile, + conf: Configuration, + subsample: Double, + inspectZip: Boolean, + seed: Long) + extends Iterator[BytesWritable] with Closeable { + + private val iterator = { + val fileSplit = new FileSplit( + new Path(new URI(file.filePath)), + file.start, + file.length, + Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + val reader = new BinaryRecordReader(subsample, inspectZip, seed) + reader.initialize(fileSplit, hadoopAttemptContext) + new RecordReaderIterator(reader) + } + + override def hasNext: Boolean = iterator.hasNext + + override def next(): BytesWritable = iterator.next() + + override def close(): Unit = iterator.close() + +} diff --git a/src/readers/src/main/scala/BinaryFileReader.scala b/src/readers/src/main/scala/BinaryFileReader.scala index b3408ecf7e5..754c2f07bc4 100644 --- a/src/readers/src/main/scala/BinaryFileReader.scala +++ b/src/readers/src/main/scala/BinaryFileReader.scala @@ -4,74 +4,62 @@ package com.microsoft.ml.spark import com.microsoft.ml.spark.schema.BinaryFileSchema -import org.apache.spark.input.PortableDataStream -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.spark.binary.BinaryFileFormat +import org.apache.spark.sql.{DataFrame, SparkSession} import scala.language.existentials -import com.microsoft.ml.spark.StreamUtilities.{ZipIterator} -import com.microsoft.ml.spark.hadoop.{SamplePathFilter, RecursiveFlag} object BinaryFileReader { - //single column of images named "image" - private val 
binaryDFSchema = StructType(StructField("value", BinaryFileSchema.columnSchema, true) :: Nil) + private def recursePath(fileSystem: FileSystem, + path: Path, + pathFilter:FileStatus => Boolean, + visitedSymlinks: Set[Path]): Array[Path] ={ + val filteredPaths = fileSystem.listStatus(path).filter(pathFilter) + val filteredDirs = filteredPaths.filter(fs => fs.isDirectory & !visitedSymlinks(fs.getPath)) + val symlinksFound = visitedSymlinks ++ filteredDirs.filter(_.isSymlink).map(_.getPath) + filteredPaths.map(_.getPath) ++ filteredDirs.map(_.getPath) + .flatMap(p => recursePath(fileSystem, p, pathFilter, symlinksFound)) + } + + def recursePath(fileSystem: FileSystem, path: Path, pathFilter:FileStatus => Boolean): Array[Path] ={ + recursePath(fileSystem, path, pathFilter, Set()) + } - /** Read the directory of images from the local or remote source + /** Read the directory of binary files from the local or remote source * - * @param path Path to the image directory - * @param recursive Recursive search flag - * @return Dataframe with a single column of "images", see imageSchema for details + * @param path Path to the directory + * @param recursive Recursive search flag + * @return DataFrame with a single column of "binaryFiles", see "columnSchema" for details */ - private[spark] def readRDD(path: String, recursive: Boolean, spark: SparkSession, - sampleRatio: Double, inspectZip: Boolean) - : RDD[(String, Array[Byte])] = { - - require(sampleRatio <= 1.0 && sampleRatio >= 0, "sampleRatio should be between 0 and 1") - - val oldRecursiveFlag = RecursiveFlag.setRecursiveFlag(Some(recursive.toString), spark) - val oldPathFilter: Option[Class[_]] = - if (sampleRatio < 1) - SamplePathFilter.setPathFilter(Some(classOf[SamplePathFilter]), Some(sampleRatio), Some(inspectZip), spark) - else - None - - var data: RDD[(String, Array[Byte])] = null - try { - val streams = spark.sparkContext.binaryFiles(path, spark.sparkContext.defaultParallelism) - .repartition(spark.sparkContext.defaultParallelism) - - // Create files RDD and load bytes - data = if (!inspectZip) { - streams.mapValues((stream: PortableDataStream) => stream.toArray) - } else { - // if inspectZip is enabled, examine/sample the contents of zip files - streams.flatMap({ case (filename: String, stream: PortableDataStream) => - if (SamplePathFilter.isZipFile(filename)) { - new ZipIterator(stream, filename, sampleRatio) - } else { - Some((filename, stream.toArray)) - } - }) - } - } - finally { - // return Hadoop flag to its original value - RecursiveFlag.setRecursiveFlag(oldRecursiveFlag, spark = spark) - SamplePathFilter.setPathFilter(oldPathFilter, spark = spark) - () + def read(path: String, recursive: Boolean, spark: SparkSession, + sampleRatio: Double = 1, inspectZip: Boolean = true, seed: Long = 0L): DataFrame = { + val p = new Path(path) + val globs = if (recursive){ + recursePath(p.getFileSystem(spark.sparkContext.hadoopConfiguration), p, {fs => fs.isDirectory}) + .map(g => g) ++ Array(p) + }else{ + Array(p) } - - data + spark.read.format(classOf[BinaryFileFormat].getName) + .option("subsample", sampleRatio) + .option("seed", seed) + .option("inspectZip",inspectZip).load(globs.map(g => g.toString):_*) } - def read(path: String, recursive: Boolean, spark: SparkSession, - sampleRatio: Double = 1, inspectZip: Boolean = true): DataFrame = { - val rowRDD = readRDD(path, recursive, spark, sampleRatio, inspectZip) - .map({row:(String, Array[Byte]) => Row(Row(row._1, row._2))}) - - spark.createDataFrame(rowRDD, binaryDFSchema) + /** Read the 
directory of binary files from the local or remote source + * + * @param path Path to the directory + * @return DataFrame with a single column of "binaryFiles", see "columnSchema" for details + */ + def stream(path: String, spark: SparkSession, + sampleRatio: Double = 1, inspectZip: Boolean = true, seed: Long = 0L): DataFrame = { + val p = new Path(path) + spark.readStream.format(classOf[BinaryFileFormat].getName) + .option("subsample", sampleRatio) + .option("seed", seed) + .option("inspectZip",inspectZip).schema(BinaryFileSchema.schema).load(p.toString) } -} +} diff --git a/src/readers/src/main/scala/ImageFileFormat.scala b/src/readers/src/main/scala/ImageFileFormat.scala new file mode 100644 index 00000000000..e6860fc4211 --- /dev/null +++ b/src/readers/src/main/scala/ImageFileFormat.scala @@ -0,0 +1,90 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package org.apache.spark.image + +import com.microsoft.ml.spark.ImageReader +import com.microsoft.ml.spark.schema.ImageSchema +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce._ +import org.apache.spark.TaskContext +import org.apache.spark.binary._ +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.SerializableConfiguration + +class ImageFileFormat extends TextBasedFileFormat with DataSourceRegister with Serializable { + + override def isSplitable(sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = false + + override def shortName(): String = "image" + + override def inferSchema(sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + Some(ImageSchema.schema) + } + + override def prepareWrite(sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new NotImplementedError("writing to image files is not supported") + } + + override def buildReader(sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val subsample = options.getOrElse("subsample","1.0").toDouble + assert(subsample>=0.0 & subsample <=1.0) + val inspectZip = options.getOrElse("inspectZip", "false").toBoolean + val seed = options.getOrElse("seed", "0").toLong + + (file: PartitionedFile) => { + val fileReader = new HadoopFileReader(file, broadcastedHadoopConf.value.value, subsample, inspectZip, seed) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => fileReader.close())) + fileReader.flatMap {bytes => + val byteArray = bytes.getBytes + ImageReader.OpenCVLoader + val rowOpt = ImageReader.decode(file.filePath, byteArray) + + rowOpt match { + case None => None + case Some(row) => + val imGenRow = new GenericInternalRow(1) + val genRow = new GenericInternalRow(ImageSchema.columnSchema.fields.length) + 
genRow.update(0, UTF8String.fromString(row.getString(0))) + genRow.update(1, row.getInt(1)) + genRow.update(2, row.getInt(2)) + genRow.update(3, row.getInt(3)) + genRow.update(4, row.getAs[Array[Byte]](4)) + imGenRow.update(0, genRow) + Some(imGenRow) + } + } + } + } + + override def toString: String = "Image" + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[BinaryFileFormat] + +} diff --git a/src/readers/src/main/scala/ImageReader.scala b/src/readers/src/main/scala/ImageReader.scala index 79615de3240..fdffeec775e 100644 --- a/src/readers/src/main/scala/ImageReader.scala +++ b/src/readers/src/main/scala/ImageReader.scala @@ -3,16 +3,37 @@ package com.microsoft.ml.spark +import com.microsoft.ml.spark.BinaryFileReader.recursePath import com.microsoft.ml.spark.schema.ImageSchema -import org.apache.spark.sql.types._ +import org.apache.hadoop.fs.Path +import org.apache.spark.image.ImageFileFormat +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.opencv.core.{Core, CvException, MatOfByte} +import org.opencv.core.{CvException, MatOfByte} import org.opencv.imgcodecs.Imgcodecs object ImageReader { - //single column of images named "image" - private val imageDFSchema = StructType(StructField("image", ImageSchema.columnSchema, true) :: Nil) + /** This object will load the openCV binaries when the object is referenced + * for the first time, subsequent references will not re-load the binaries. + * In spark, this loads one copy for each running jvm, instead of once per partition. + * This technique is similar to that used by the cntk_jni jar, + * but in the case where microsoft cannot edit the jar + */ + object OpenCVLoader { + import org.opencv.core.Core + new NativeLoader("/nu/pattern/opencv").loadLibraryByName(Core.NATIVE_LIBRARY_NAME) + } + + private[spark] def loadOpenCVFunc[A](it: Iterator[A]) = { + OpenCVLoader + it + } + + private[spark] def loadOpenCV(df: DataFrame):DataFrame ={ + val encoder = RowEncoder(df.schema) + df.mapPartitions(loadOpenCVFunc)(encoder) + } /** Convert the image from compressd (jpeg, etc.) into OpenCV representation and store it in Row * See ImageSchema for details. 
@@ -21,7 +42,7 @@ object ImageReader { * @param bytes image bytes (for example, jpeg) * @return returns None if decompression fails */ - private[spark] def decode(filename: String, bytes: Array[Byte]): Option[Row] = { + def decode(filename: String, bytes: Array[Byte]): Option[Row] = { val mat = new MatOfByte(bytes: _*) val decodedOpt = try { Some(Imgcodecs.imdecode(mat, Imgcodecs.CV_LOAD_IMAGE_COLOR)) @@ -45,20 +66,35 @@ object ImageReader { * * @param path Path to the image directory * @param recursive Recursive search flag - * @return Dataframe with a single column of "images", see imageSchema for details + * @return DataFrame with a single column of "images", see "columnSchema" for details */ def read(path: String, recursive: Boolean, spark: SparkSession, - sampleRatio: Double = 1, inspectZip: Boolean = true): DataFrame = { - - val binaryRDD = BinaryFileReader.readRDD(path, recursive, spark, sampleRatio, inspectZip) - val binaryRDDlib = ImageSchema.loadOpenCV(binaryRDD) - - val validImages = binaryRDDlib.flatMap { - case (filename, bytes) => { - decode(filename, bytes).map(x => Row(x)) - } + sampleRatio: Double = 1, inspectZip: Boolean = true, seed: Long = 0L): DataFrame = { + val p = new Path(path) + val globs = if (recursive){ + recursePath(p.getFileSystem(spark.sparkContext.hadoopConfiguration), p, {fs => fs.isDirectory}) + .map(g => g) ++ Array(p) + }else{ + Array(p) } + spark.read.format(classOf[ImageFileFormat].getName) + .option("subsample", sampleRatio) + .option("seed", seed) + .option("inspectZip", inspectZip).load(globs.map(_.toString):_*) + } - spark.createDataFrame(validImages, imageDFSchema) + /** Read the directory of image files from the local or remote source + * + * @param path Path to the directory + * @return DataFrame with a single column of "imageFiles", see "columnSchema" for details + */ + def stream(path: String, spark: SparkSession, + sampleRatio: Double = 1, inspectZip: Boolean = true, seed: Long = 0L): DataFrame = { + val p = new Path(path) + spark.readStream.format(classOf[ImageFileFormat].getName) + .option("subsample", sampleRatio) + .option("seed", seed) + .option("inspectZip",inspectZip).schema(ImageSchema.schema).load(p.toString) } + } diff --git a/src/readers/src/main/scala/Readers.scala b/src/readers/src/main/scala/Readers.scala index d45189fcd7c..0bb0c6bcc1e 100644 --- a/src/readers/src/main/scala/Readers.scala +++ b/src/readers/src/main/scala/Readers.scala @@ -24,8 +24,8 @@ object Readers { * @return Dataframe with a single column "value" of binary files, see BinaryFileSchema for details */ def readBinaryFiles(path: String, recursive: Boolean, - sampleRatio: Double = 1, inspectZip: Boolean = true): DataFrame = - BinaryFileReader.read(path, recursive, sparkSession, sampleRatio, inspectZip) + sampleRatio: Double = 1, inspectZip: Boolean = true, seed: Long = 0L): DataFrame = + BinaryFileReader.read(path, recursive, sparkSession, sampleRatio, inspectZip, seed) /** Read the directory of images from the local or remote source * @@ -36,8 +36,8 @@ object Readers { * @return Dataframe with a single column "image" of images, see ImageSchema for details */ def readImages(path: String, recursive: Boolean, - sampleRatio: Double = 1, inspectZip: Boolean = true): DataFrame = - ImageReader.read(path, recursive, sparkSession, sampleRatio, inspectZip) + sampleRatio: Double = 1, inspectZip: Boolean = true, seed: Long = 0L): DataFrame = + ImageReader.read(path, recursive, sparkSession, sampleRatio, inspectZip, seed) } implicit def ImplicitSession(sparkSession: 
SparkSession):Session = new Session(sparkSession) diff --git a/src/readers/src/test/scala/BinaryFileReaderSuite.scala b/src/readers/src/test/scala/BinaryFileReaderSuite.scala index d9466038c75..83460a355ab 100644 --- a/src/readers/src/test/scala/BinaryFileReaderSuite.scala +++ b/src/readers/src/test/scala/BinaryFileReaderSuite.scala @@ -3,26 +3,56 @@ package com.microsoft.ml.spark -import com.microsoft.ml.spark.FileReaderSuiteUtils._ +import java.io.FileOutputStream + import com.microsoft.ml.spark.Readers.implicits._ import com.microsoft.ml.spark.schema.BinaryFileSchema.isBinaryFile +import com.microsoft.ml.spark.FileUtilities.{File, zipFolder} +import com.microsoft.ml.spark.schema.BinaryFileSchema +import org.apache.commons.io.FileUtils +import org.apache.spark.binary.BinaryFileFormat -class BinaryFileReaderSuite extends TestBase { +trait FileReaderUtils { + val fileLocation = s"${sys.env("DATASETS_HOME")}" + val imagesDirectory: String = fileLocation + "/Images" + val groceriesDirectory: String = imagesDirectory + "/Grocery" + val cifarDirectory: String = imagesDirectory + "/CIFAR" - test("binary dataframe") { + def createZip(directory: String): Unit ={ + val dir = new File(directory) + val zipfile = new File(directory + ".zip") + if (!zipfile.exists()) zipFolder(dir, zipfile) + } - val data = session.readBinaryFiles(groceriesDirectory, recursive = true) + def createZips(): Unit ={ + createZip(groceriesDirectory) + createZip(cifarDirectory) + } - println(time { data.count }) + def tryWithRetries[T](times: Array[Int] = Array(0, 100, 500, 1000, 3000, 5000))(block: () => T): T = { + for ((t, i) <- times.zipWithIndex){ + try{ + return block() + } catch { + case _: Exception if (i + 1) < times.length => + Thread.sleep(t.toLong) + } + } + throw new RuntimeException("This error should not occur, bug has been introduced in tryWithRetries") + } +} - assert(isBinaryFile(data, "value")) +class BinaryFileReaderSuite extends TestBase with FileReaderUtils { + test("binary dataframe") { + val data = session.readBinaryFiles(groceriesDirectory, recursive = true) + println(time { data.count }) + assert(isBinaryFile(data, "value")) val paths = data.select("value.path") //make sure that SQL has access to the sub-fields assert(paths.count == 31) //note that text file is also included } test("sample ratio test") { - val all = session.readBinaryFiles(groceriesDirectory, recursive = true, sampleRatio = 1.0) val sampled = session.readBinaryFiles(groceriesDirectory, recursive = true, sampleRatio = 0.5) val count = sampled.count @@ -31,7 +61,7 @@ class BinaryFileReaderSuite extends TestBase { test("with zip file") { /* remove when datasets/Images is updated */ - createZips + createZips() val images = session.readBinaryFiles(imagesDirectory, recursive = true) assert(images.count == 74) @@ -40,4 +70,60 @@ class BinaryFileReaderSuite extends TestBase { assert(images1.count == 39) } + test("handle folders with spaces") { + /* remove when datasets/Images is updated */ + val newDirTop = new File(tmpDir.toFile, "foo bar") + val newDirMid = new File(newDirTop, "fooey barey") + FileUtils.forceMkdir(newDirMid) + try { + val fos = new FileOutputStream(new File(newDirMid, "foo.txt")) + try { + fos.write((1 to 10).map(_.toByte).toArray) + } finally { + fos.close() + } + val files = session.readBinaryFiles(newDirTop.getAbsolutePath, recursive = true) + assert(files.count == 1) + } finally { + FileUtils.forceDelete(newDirTop) + } + } + + test("binary files should allow recursion"){ + val df = session + .read + 
.format(classOf[BinaryFileFormat].getName) + .load(groceriesDirectory + "**/*") + assert(df.count()==31) + df.printSchema() + } + + test("static load with new reader"){ + val df = session + .read + .format(classOf[BinaryFileFormat].getName) + .option("subsample", .5) + .load(cifarDirectory) + assert(df.count()==3) + } + + test("structured streaming with binary files"){ + val imageDF = session + .readStream + .format(classOf[BinaryFileFormat].getName) + .schema(BinaryFileSchema.schema) + .load(cifarDirectory) + + val q1 = imageDF.writeStream + .format("memory") + .queryName("images") + .start() + + tryWithRetries(){ () => + val df = session.sql("select * from images") + assert(df.count() == 6) + } + q1.stop() + } + } diff --git a/src/readers/src/test/scala/ImageReaderSuite.scala b/src/readers/src/test/scala/ImageReaderSuite.scala index 41f56350756..b220636e3aa 100644 --- a/src/readers/src/test/scala/ImageReaderSuite.scala +++ b/src/readers/src/test/scala/ImageReaderSuite.scala @@ -3,54 +3,57 @@ package com.microsoft.ml.spark -import com.microsoft.ml.spark.FileUtilities._ import com.microsoft.ml.spark.Readers.implicits._ -import com.microsoft.ml.spark.schema.ImageSchema.isImage - -object FileReaderSuiteUtils { - val fileLocation = s"${sys.env("DATASETS_HOME")}" - val imagesDirectory = fileLocation + "/Images" - val groceriesDirectory = imagesDirectory + "/Grocery" - val cifarDirectory = imagesDirectory + "/CIFAR" - - def createZip(directory: String): Unit ={ - val dir = new File(directory) - val zipfile = new File(directory + ".zip") - if (!zipfile.exists()) zipFolder(dir, zipfile) - } - - def createZips(): Unit ={ - createZip(groceriesDirectory) - createZip(cifarDirectory) - } -} +import com.microsoft.ml.spark.schema.ImageSchema +import org.apache.spark.image.ImageFileFormat -import com.microsoft.ml.spark.FileReaderSuiteUtils._ - -class ImageReaderSuite extends TestBase { +class ImageReaderSuite extends TestBase with FileReaderUtils { test("image dataframe") { - val images = session.readImages(groceriesDirectory, recursive = true) - println(time { images.count }) - - assert(isImage(images, "image")) // make sure the column "images" exists and has the right type - + assert(ImageSchema.isImage(images, "image")) // make sure the column "images" exists and has the right type val paths = images.select("image.path") //make sure that SQL has access to the sub-fields assert(paths.count == 30) - val areas = images.select(images("image.width") * images("image.height")) //more complicated SQL statement - println(s" area of image 1 ${areas.take(1)(0)}") } + test("read images with subsample"){ + val imageDF = session + .read + .format(classOf[ImageFileFormat].getName) + .option("subsample", .5) + .load(cifarDirectory) + assert(imageDF.count() == 3) + } + + test("structured streaming with images"){ + val schema = ImageSchema.schema + val imageDF = session + .readStream + .format(classOf[ImageFileFormat].getName) + .schema(schema) + .load(cifarDirectory) + + val q1 = imageDF.select("image.path").writeStream + .format("memory") + .queryName("images") + .start() + + tryWithRetries() {() => + val df = session.sql("select * from images") + assert(df.count() == 6) + } + q1.stop() + } + test("with zip file") { /* remove when datasets/Images is updated */ - createZips + createZips() val images = session.readImages(imagesDirectory, recursive = true) - assert(isImage(images, "image")) + assert(ImageSchema.isImage(images, "image")) assert(images.count == 72) val images1 = session.readImages(imagesDirectory, 
recursive = true, inspectZip = false)
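For readers following the new streaming path end to end, the sketch below pulls together the pieces this diff introduces (`ImageFileFormat`, `ImageSchema.schema`, and the `subsample`/`seed` options). It is an illustrative sketch only, not part of the patch; the directory path, application name, and query name are hypothetical.

```scala
// Minimal sketch of using the streaming image reader added in this diff.
// Assumptions: a local Spark session and an image directory at /datasets/CIFAR10/test.
import com.microsoft.ml.spark.schema.ImageSchema
import org.apache.spark.image.ImageFileFormat
import org.apache.spark.sql.SparkSession

object StreamImagesExample extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("stream-images-example")
    .getOrCreate()

  // Read a directory of images as a stream, sampling roughly half of the files.
  val imageStream = spark.readStream
    .format(classOf[ImageFileFormat].getName)
    .schema(ImageSchema.schema)        // file-based streaming sources need an explicit schema
    .option("subsample", 0.5)          // fraction of files to read, as handled in BinaryRecordReader
    .option("seed", 0L)                // seed mixed into the per-file sampling RNG
    .load("/datasets/CIFAR10/test")    // hypothetical image directory

  // Write the image heights to an in-memory table so they can be queried with SQL.
  val query = imageStream.select("image.height").writeStream
    .format("memory")
    .queryName("heights")
    .start()

  query.processAllAvailable()          // wait for the initial batch of existing files
  spark.sql("select count(*) from heights").show()
  query.stop()
  spark.stop()
}
```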